diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index 72c8dd60e63..6f5b0cc7615 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -10,6 +10,7 @@ using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal @@ -75,154 +76,6 @@ public virtual SqlExpression Translate([NotNull] Expression expression) return null; } - private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor - { - protected override Expression VisitExtension(Expression node) - { - Check.NotNull(node, nameof(node)); - - if (node is SqlExpression sqlExpression - && sqlExpression.TypeMapping == null) - { - throw new InvalidOperationException(CoreStrings.NullTypeMappingInSqlTree); - } - - return base.VisitExtension(node); - } - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - protected override Expression VisitMember(MemberExpression memberExpression) - { - Check.NotNull(memberExpression, nameof(memberExpression)); - - return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result) - ? result - : TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression) - ? null - : _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type); - } - - private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression) - { - source = source.UnwrapTypeConversion(out var convertedType); - Expression visitedExpression; - switch (source) - { - case EntityShaperExpression entityShaperExpression: - visitedExpression = Visit(entityShaperExpression.ValueBufferExpression); - break; - - case MemberExpression memberExpression: - TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out visitedExpression); - break; - - case MethodCallExpression methodCallExpression - when methodCallExpression.TryGetEFPropertyArguments(out var innerSource, out var innerPropertyName): - TryBindMember(innerSource, MemberIdentity.Create(innerPropertyName), out visitedExpression); - break; - - case MethodCallExpression methodCallExpression - when methodCallExpression.TryGetIndexerArguments(_model, out var innerSource, out var innerPropertyName): - TryBindMember(innerSource, MemberIdentity.Create(innerPropertyName), out visitedExpression); - break; - - default: - visitedExpression = null; - break; - } - - if (visitedExpression is EntityProjectionExpression entityProjectionExpression) - { - convertedType ??= entityProjectionExpression.Type; - expression = member.MemberInfo != null - ? entityProjectionExpression.BindMember(member.MemberInfo, convertedType, clientEval: false, out _) - : entityProjectionExpression.BindMember(member.Name, convertedType, clientEval: false, out _); - - return expression != null; - } - - expression = null; - return false; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) - { - return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) - ? result - : null; - } - - // EF Indexer property - if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) - { - return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) ? result : null; - } - - if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject)) - { - return null; - } - - var arguments = new SqlExpression[methodCallExpression.Arguments.Count]; - for (var i = 0; i < arguments.Length; i++) - { - var argument = methodCallExpression.Arguments[i]; - if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) - { - return null; - } - - arguments[i] = sqlArgument; - } - - return _methodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); - } - - private static Expression TryRemoveImplicitConvert(Expression expression) - { - if (expression is UnaryExpression unaryExpression - && (unaryExpression.NodeType == ExpressionType.Convert - || unaryExpression.NodeType == ExpressionType.ConvertChecked)) - { - var innerType = unaryExpression.Operand.Type.UnwrapNullableType(); - if (innerType.IsEnum) - { - innerType = Enum.GetUnderlyingType(innerType); - } - - var convertedType = unaryExpression.Type.UnwrapNullableType(); - - if (innerType == convertedType - || (convertedType == typeof(int) - && (innerType == typeof(byte) - || innerType == typeof(sbyte) - || innerType == typeof(char) - || innerType == typeof(short) - || innerType == typeof(ushort)))) - { - return TryRemoveImplicitConvert(unaryExpression.Operand); - } - } - - return expression; - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -300,72 +153,52 @@ protected override Expression VisitConditional(ConditionalExpression conditional /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitUnary(UnaryExpression unaryExpression) - { - Check.NotNull(unaryExpression, nameof(unaryExpression)); - - var operand = Visit(unaryExpression.Operand); + protected override Expression VisitConstant(ConstantExpression constantExpression) + => new SqlConstantExpression(Check.NotNull(constantExpression, nameof(constantExpression)), null); - if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) - { - return null; - } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitExtension(Expression extensionExpression) + { + Check.NotNull(extensionExpression, nameof(extensionExpression)); - switch (unaryExpression.NodeType) + switch (extensionExpression) { - case ExpressionType.Not: - return _sqlExpressionFactory.Not(sqlOperand); + case EntityProjectionExpression _: + case SqlExpression _: + return extensionExpression; - case ExpressionType.Negate: - return _sqlExpressionFactory.Negate(sqlOperand); + case EntityShaperExpression entityShaperExpression: + var result = Visit(entityShaperExpression.ValueBufferExpression); - case ExpressionType.Convert: - case ExpressionType.ConvertChecked: - // Object convert needs to be converted to explicit cast when mismatching types - if (operand.Type.IsInterface - && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) - || unaryExpression.Type.UnwrapNullableType() == operand.Type - || unaryExpression.Type.UnwrapNullableType() == typeof(Enum)) + if (result.NodeType == ExpressionType.Convert + && result.Type == typeof(ValueBuffer) + && result is UnaryExpression outerUnary + && outerUnary.Operand.NodeType == ExpressionType.Convert + && outerUnary.Operand.Type == typeof(object)) { - return sqlOperand; + result = ((UnaryExpression)outerUnary.Operand).Operand; } - break; - } - - return null; - } - - private SqlConstantExpression GetConstantOrNull(Expression expression) - { - if (CanEvaluate(expression)) - { - var value = Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(); - return new SqlConstantExpression(Expression.Constant(value, expression.Type), null); - } - - return null; - } - - private static bool CanEvaluate(Expression expression) - { -#pragma warning disable IDE0066 // Convert switch statement to expression - switch (expression) -#pragma warning restore IDE0066 // Convert switch statement to expression - { - case ConstantExpression constantExpression: - return true; + if (result is EntityProjectionExpression entityProjectionExpression) + { + return new EntityReferenceExpression(entityProjectionExpression); + } - case NewExpression newExpression: - return newExpression.Arguments.All(e => CanEvaluate(e)); + throw new InvalidOperationException("Randomization"); - case MemberInitExpression memberInitExpression: - return CanEvaluate(memberInitExpression.NewExpression) - && memberInitExpression.Bindings.All( - mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)); + case ProjectionBindingExpression projectionBindingExpression: + return projectionBindingExpression.ProjectionMember != null + ? ((SelectExpression)projectionBindingExpression.QueryExpression) + .GetMappedProjection(projectionBindingExpression.ProjectionMember) + : null; default: - return false; + return null; } } @@ -375,12 +208,7 @@ private static bool CanEvaluate(Expression expression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitNew(NewExpression node) - { - Check.NotNull(node, nameof(node)); - - return GetConstantOrNull(node); - } + protected override Expression VisitInvocation(InvocationExpression invocationExpression) => null; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -388,12 +216,7 @@ protected override Expression VisitNew(NewExpression node) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitMemberInit(MemberInitExpression node) - { - Check.NotNull(node, nameof(node)); - - return GetConstantOrNull(node); - } + protected override Expression VisitLambda(Expression lambdaExpression) => null; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -401,12 +224,7 @@ protected override Expression VisitMemberInit(MemberInitExpression node) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitNewArray(NewArrayExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } + protected override Expression VisitListInit(ListInitExpression listInitExpression) => null; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -414,11 +232,16 @@ protected override Expression VisitNewArray(NewArrayExpression node) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitListInit(ListInitExpression node) + protected override Expression VisitMember(MemberExpression memberExpression) { - Check.NotNull(node, nameof(node)); + Check.NotNull(memberExpression, nameof(memberExpression)); - return null; + var innerExpression = Visit(memberExpression.Expression); + + return TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member)) + ?? (TranslationFailed(memberExpression.Expression, innerExpression, out var sqlInnerExpression) + ? null + : _memberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type)); } /// @@ -427,12 +250,8 @@ protected override Expression VisitListInit(ListInitExpression node) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitInvocation(InvocationExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } + protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) + => GetConstantOrNull(Check.NotNull(memberInitExpression, nameof(memberInitExpression))); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -440,11 +259,34 @@ protected override Expression VisitInvocation(InvocationExpression node) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitLambda(Expression node) + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { - Check.NotNull(node, nameof(node)); + Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - return null; + if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName) + || methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) + { + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName)); + } + + if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out var sqlObject)) + { + return null; + } + + var arguments = new SqlExpression[methodCallExpression.Arguments.Count]; + for (var i = 0; i < arguments.Length; i++) + { + var argument = methodCallExpression.Arguments[i]; + if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) + { + return null; + } + + arguments[i] = sqlArgument; + } + + return _methodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); } /// @@ -453,12 +295,16 @@ protected override Expression VisitLambda(Expression node) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitConstant(ConstantExpression constantExpression) - { - Check.NotNull(constantExpression, nameof(constantExpression)); + protected override Expression VisitNew(NewExpression newExpression) + => GetConstantOrNull(Check.NotNull(newExpression, nameof(newExpression))); - return new SqlConstantExpression(constantExpression, null); - } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) => null; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -467,11 +313,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override Expression VisitParameter(ParameterExpression parameterExpression) - { - Check.NotNull(parameterExpression, nameof(parameterExpression)); - - return new SqlParameterExpression(parameterExpression, null); - } + => new SqlParameterExpression(Check.NotNull(parameterExpression, nameof(parameterExpression)), null); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -479,27 +321,129 @@ protected override Expression VisitParameter(ParameterExpression parameterExpres /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - protected override Expression VisitExtension(Expression extensionExpression) + protected override Expression VisitUnary(UnaryExpression unaryExpression) { - Check.NotNull(extensionExpression, nameof(extensionExpression)); + Check.NotNull(unaryExpression, nameof(unaryExpression)); - switch (extensionExpression) + var operand = Visit(unaryExpression.Operand); + + if (operand is EntityReferenceExpression entityReferenceExpression + && (unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.ConvertChecked + || unaryExpression.NodeType == ExpressionType.TypeAs)) { - case EntityProjectionExpression _: - case SqlExpression _: - return extensionExpression; + return entityReferenceExpression.Convert(unaryExpression.Type); + } - case EntityShaperExpression entityShaperExpression: - return Visit(entityShaperExpression.ValueBufferExpression); + if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) + { + return null; + } - case ProjectionBindingExpression projectionBindingExpression: - return projectionBindingExpression.ProjectionMember != null - ? ((SelectExpression)projectionBindingExpression.QueryExpression) - .GetMappedProjection(projectionBindingExpression.ProjectionMember) - : null; + switch (unaryExpression.NodeType) + { + case ExpressionType.Not: + return _sqlExpressionFactory.Not(sqlOperand); + + case ExpressionType.Negate: + return _sqlExpressionFactory.Negate(sqlOperand); + + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + // Object convert needs to be converted to explicit cast when mismatching types + if (operand.Type.IsInterface + && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) + || unaryExpression.Type.UnwrapNullableType() == operand.Type + || unaryExpression.Type.UnwrapNullableType() == typeof(Enum)) + { + return sqlOperand; + } + + break; + } + + return null; + } + + private Expression TryBindMember(Expression source, MemberIdentity member) + { + if (!(source is EntityReferenceExpression entityReferenceExpression)) + { + return null; + } + + var result = member.MemberInfo != null + ? entityReferenceExpression.ParameterEntity.BindMember(member.MemberInfo, entityReferenceExpression.Type, clientEval: false, out _) + : entityReferenceExpression.ParameterEntity.BindMember(member.Name, entityReferenceExpression.Type, clientEval: false, out _); + + return result switch + { + EntityProjectionExpression entityProjectionExpression => new EntityReferenceExpression(entityProjectionExpression), + ObjectArrayProjectionExpression objectArrayProjectionExpression + => new EntityReferenceExpression(objectArrayProjectionExpression.InnerProjection), + _ => result + }; + } + + private static Expression TryRemoveImplicitConvert(Expression expression) + { + if (expression is UnaryExpression unaryExpression + && (unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.ConvertChecked)) + { + var innerType = unaryExpression.Operand.Type.UnwrapNullableType(); + if (innerType.IsEnum) + { + innerType = Enum.GetUnderlyingType(innerType); + } + + var convertedType = unaryExpression.Type.UnwrapNullableType(); + + if (innerType == convertedType + || (convertedType == typeof(int) + && (innerType == typeof(byte) + || innerType == typeof(sbyte) + || innerType == typeof(char) + || innerType == typeof(short) + || innerType == typeof(ushort)))) + { + return TryRemoveImplicitConvert(unaryExpression.Operand); + } + } + + return expression; + } + + private SqlConstantExpression GetConstantOrNull(Expression expression) + { + if (CanEvaluate(expression)) + { + var value = Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(); + return new SqlConstantExpression(Expression.Constant(value, expression.Type), null); + } + + return null; + } + + private static bool CanEvaluate(Expression expression) + { +#pragma warning disable IDE0066 // Convert switch statement to expression + switch (expression) +#pragma warning restore IDE0066 // Convert switch statement to expression + { + case ConstantExpression constantExpression: + return true; + + case NewExpression newExpression: + return newExpression.Arguments.All(e => CanEvaluate(e)); + + case MemberInitExpression memberInitExpression: + return CanEvaluate(memberInitExpression.NewExpression) + && memberInitExpression.Bindings.All( + mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)); default: - return null; + return false; } } @@ -516,5 +460,55 @@ private bool TranslationFailed(Expression original, Expression translation, out castTranslation = translation as SqlExpression; return false; } + + private sealed class EntityReferenceExpression : Expression + { + public EntityReferenceExpression(EntityProjectionExpression parameter) + { + ParameterEntity = parameter; + EntityType = parameter.EntityType; + Type = EntityType.ClrType; + } + + private EntityReferenceExpression(EntityProjectionExpression parameter, Type type) + { + ParameterEntity = parameter; + EntityType = parameter.EntityType; + Type = type; + } + + public EntityProjectionExpression ParameterEntity { get; } + public IEntityType EntityType { get; } + + public override Type Type { get; } + public override ExpressionType NodeType => ExpressionType.Extension; + + public Expression Convert(Type type) + { + if (type == typeof(object) // Ignore object conversion + || type.IsAssignableFrom(Type)) // Ignore casting to base type/interface + { + return this; + } + + return new EntityReferenceExpression(ParameterEntity, type); + } + } + + private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor + { + protected override Expression VisitExtension(Expression extensionExpression) + { + Check.NotNull(extensionExpression, nameof(extensionExpression)); + + if (extensionExpression is SqlExpression sqlExpression + && sqlExpression.TypeMapping == null) + { + throw new InvalidOperationException(CoreStrings.NullTypeMappingInSqlTree); + } + + return base.VisitExtension(extensionExpression); + } + } } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index b1a31ce5550..6bf26ae4458 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -25,6 +25,12 @@ public class InMemoryExpressionTranslatingExpressionVisitor : ExpressionVisitor { private const string _compiledQueryParameterPrefix = "__"; + private static readonly MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0]; + + private static readonly MethodInfo _getParameterValueMethodInfo + = typeof(InMemoryExpressionTranslatingExpressionVisitor) + .GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue)); + private static readonly MethodInfo _likeMethodInfo = typeof(DbFunctionsExtensions).GetRuntimeMethod( nameof(DbFunctionsExtensions.Like), @@ -52,7 +58,7 @@ private static string BuildEscapeRegexCharsPattern(IEnumerable regexSpecia => string.Join("|", regexSpecialChars.Select(c => @"\" + c)); private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; - private readonly EntityProjectionFindingExpressionVisitor _entityProjectionFindingExpressionVisitor; + private readonly EntityReferenceFindingExpressionVisitor _entityReferenceFindingExpressionVisitor; private readonly IModel _model; public InMemoryExpressionTranslatingExpressionVisitor( @@ -60,113 +66,15 @@ public InMemoryExpressionTranslatingExpressionVisitor( [NotNull] IModel model) { _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; - _entityProjectionFindingExpressionVisitor = new EntityProjectionFindingExpressionVisitor(); + _entityReferenceFindingExpressionVisitor = new EntityReferenceFindingExpressionVisitor(); _model = model; } - private sealed class EntityProjectionFindingExpressionVisitor : ExpressionVisitor - { - private bool _found; - - public bool Find(Expression expression) - { - _found = false; - - Visit(expression); - - return _found; - } - - public override Expression Visit(Expression expression) - { - if (_found) - { - return expression; - } - - if (expression is EntityProjectionExpression) - { - _found = true; - return expression; - } - - return base.Visit(expression); - } - } - - private sealed class PropertyFindingExpressionVisitor : ExpressionVisitor - { - private readonly IModel _model; - private IProperty _property; - - public PropertyFindingExpressionVisitor(IModel model) - { - _model = model; - } - - public IProperty Find(Expression expression) - { - Visit(expression); - - return _property; - } - - protected override Expression VisitMember(MemberExpression memberExpression) - { - var entityType = FindEntityType(memberExpression.Expression); - if (entityType != null) - { - _property = GetProperty(entityType, MemberIdentity.Create(memberExpression.Member)); - } - - return memberExpression; - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName) - || methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) - { - var entityType = FindEntityType(source); - if (entityType != null) - { - _property = GetProperty(entityType, MemberIdentity.Create(propertyName)); - } - } - - return methodCallExpression; - } - - private static IProperty GetProperty(IEntityType entityType, MemberIdentity memberIdentity) - => memberIdentity.MemberInfo != null - ? entityType.FindProperty(memberIdentity.MemberInfo) - : entityType.FindProperty(memberIdentity.Name); - - private static IEntityType FindEntityType(Expression source) - { - source = source.UnwrapTypeConversion(out var convertedType); - - if (source is EntityShaperExpression entityShaperExpression) - { - var entityType = entityShaperExpression.EntityType; - if (convertedType != null) - { - entityType = entityType.GetRootType().GetDerivedTypesInclusive() - .FirstOrDefault(et => et.ClrType == convertedType); - } - - return entityType; - } - - return null; - } - } - public virtual Expression Translate([NotNull] Expression expression) { var result = Visit(expression); - return _entityProjectionFindingExpressionVisitor.Find(result) + return _entityReferenceFindingExpressionVisitor.Find(result) ? null : result; } @@ -254,18 +162,36 @@ protected override Expression VisitConditional(ConditionalExpression conditional return Expression.Condition(test, ifTrue, ifFalse); } - protected override Expression VisitMember(MemberExpression memberExpression) + protected override Expression VisitExtension(Expression extensionExpression) { - Check.NotNull(memberExpression, nameof(memberExpression)); + Check.NotNull(extensionExpression, nameof(extensionExpression)); - if (TryBindMember( - memberExpression.Expression, - MemberIdentity.Create(memberExpression.Member), - memberExpression.Type, - out var result)) + switch (extensionExpression) { - return result; + case EntityProjectionExpression _: + return extensionExpression; + + case EntityShaperExpression entityShaperExpression: + return new EntityReferenceExpression(entityShaperExpression); + + case ProjectionBindingExpression projectionBindingExpression: + return projectionBindingExpression.ProjectionMember != null + ? ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression) + .GetMappedProjection(projectionBindingExpression.ProjectionMember) + : null; + + default: + return null; } + } + + protected override Expression VisitInvocation(InvocationExpression invocationExpression) => null; + protected override Expression VisitLambda(Expression lambdaExpression) => null; + protected override Expression VisitListInit(ListInitExpression listInitExpression) => null; + + protected override Expression VisitMember(MemberExpression memberExpression) + { + Check.NotNull(memberExpression, nameof(memberExpression)); var innerExpression = Visit(memberExpression.Expression); if (memberExpression.Expression != null @@ -274,6 +200,11 @@ protected override Expression VisitMember(MemberExpression memberExpression) return null; } + if (TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member), memberExpression.Type) is Expression result) + { + return result; + } + var updatedMemberExpression = (Expression)memberExpression.Update(innerExpression); if (innerExpression != null && innerExpression.Type.IsNullableType() @@ -295,104 +226,20 @@ static bool ShouldApplyNullProtectionForMemberAccess(Type callerType, string mem && (memberName == nameof(Nullable.Value) || memberName == nameof(Nullable.HasValue))); } - private bool TryBindMember(Expression source, MemberIdentity memberIdentity, Type type, out Expression result) - { - source = source.UnwrapTypeConversion(out var convertedType); - result = null; - if (source is EntityShaperExpression entityShaperExpression) - { - var entityType = entityShaperExpression.EntityType; - if (convertedType != null) - { - entityType = entityType.GetRootType().GetDerivedTypesInclusive() - .FirstOrDefault(et => et.ClrType == convertedType); - if (entityType == null) - { - return false; - } - } - - var property = memberIdentity.MemberInfo != null - ? entityType.FindProperty(memberIdentity.MemberInfo) - : entityType.FindProperty(memberIdentity.Name); - if (property != null - && Visit(entityShaperExpression.ValueBufferExpression) is EntityProjectionExpression entityProjectionExpression - && (entityProjectionExpression.EntityType.IsAssignableFrom(property.DeclaringEntityType) - || property.DeclaringEntityType.IsAssignableFrom(entityProjectionExpression.EntityType))) - { - result = BindProperty(entityProjectionExpression, property); - - // if the result type change was just nullability change e.g from int to int? - // we want to preserve the new type for null propagation - if (result.Type != type - && !(result.Type.IsNullableType() - && !type.IsNullableType() - && result.Type.UnwrapNullableType() == type)) - { - result = Expression.Convert(result, type); - } - - return true; - } - } - - return false; - } - - private static bool IsConvertedToNullable(Expression result, Expression original) - => result.Type.IsNullableType() - && !original.Type.IsNullableType() - && result.Type.UnwrapNullableType() == original.Type; - - private static Expression ConvertToNullable(Expression expression) - => !expression.Type.IsNullableType() - ? Expression.Convert(expression, expression.Type.MakeNullable()) - : expression; - - private static Expression ConvertToNonNullable(Expression expression) - => expression.Type.IsNullableType() - ? Expression.Convert(expression, expression.Type.UnwrapNullableType()) - : expression; - - private static Expression BindProperty(EntityProjectionExpression entityProjectionExpression, IProperty property) - => entityProjectionExpression.BindProperty(property); - - private static Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) - { - if (methodCallExpression.Arguments.Count == 1) - { - return groupByShaperExpression.ElementSelector; - } - - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } - - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); - } - - private Expression GetPredicate(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) { - if (methodCallExpression.Arguments.Count == 1) + var expression = Visit(memberAssignment.Expression); + if (expression == null) { return null; } - if (methodCallExpression.Arguments.Count == 2) + if (IsConvertedToNullable(expression, memberAssignment.Expression)) { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); + expression = ConvertToNonNullable(expression); } - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + return memberAssignment.Update(expression); } protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) @@ -408,18 +255,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - if (TryBindMember(source, MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result)) - { - return result; - } - - throw new InvalidOperationException(CoreStrings.EFPropertyCalledWithWrongPropertyName); + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type) + ?? throw new InvalidOperationException(CoreStrings.QueryUnableToTranslateEFProperty(methodCallExpression.Print())); } // EF Indexer property if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) { - return TryBindMember(source, MemberIdentity.Create(propertyName), methodCallExpression.Type, out var result) ? result : null; + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName), methodCallExpression.Type); } // GroupBy Aggregate case @@ -436,7 +279,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Enumerable.Min): case nameof(Enumerable.Sum): { - var translation = Translate(GetSelector(methodCallExpression, groupByShaperExpression)); + var translation = Translate(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)); if (translation == null) { return null; @@ -468,7 +311,7 @@ MethodInfo GetMethod() case nameof(Enumerable.LongCount): { var countMethod = string.Equals(methodName, nameof(Enumerable.Count)); - var predicate = GetPredicate(methodCallExpression, groupByShaperExpression); + var predicate = GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression); if (predicate == null) { return Expression.Call( @@ -511,58 +354,22 @@ MethodInfo GetMethod() return null; } - subquery.ApplyProjection(); - if (subquery.Projection.Count != 1) - { - return null; - } - - // Unwrap ResultEnumerable - var selectMethod = (MethodCallExpression)subquery.ServerQueryExpression; - var resultEnumerable = (NewExpression)selectMethod.Arguments[0]; - var resultFunc = ((LambdaExpression)resultEnumerable.Arguments[0]).Body; - // New ValueBuffer construct - if (resultFunc is NewExpression newValueBufferExpression) + if (subqueryTranslation.ShaperExpression is EntityShaperExpression entityShaperExpression) { - Expression result; - var innerExpression = ((NewArrayExpression)newValueBufferExpression.Arguments[0]).Expressions[0]; - result = innerExpression is UnaryExpression unaryExpression - && innerExpression.NodeType == ExpressionType.Convert - && innerExpression.Type == typeof(object) - ? unaryExpression.Operand - : innerExpression; - - return result.Type == methodCallExpression.Type - ? result - : Expression.Convert(result, methodCallExpression.Type); + return new EntityReferenceExpression(subqueryTranslation); } - var selector = (LambdaExpression)selectMethod.Arguments[1]; - var readValueExpression = ((NewArrayExpression)((NewExpression)selector.Body).Arguments[0]).Expressions[0]; - if (readValueExpression is UnaryExpression unaryExpression2 - && unaryExpression2.NodeType == ExpressionType.Convert - && unaryExpression2.Type == typeof(object)) +#pragma warning disable IDE0046 // Convert to conditional expression + if (!(subqueryTranslation.ShaperExpression is ProjectionBindingExpression projectionBindingExpression)) +#pragma warning restore IDE0046 // Convert to conditional expression { - readValueExpression = unaryExpression2.Operand; + return null; } - var valueBufferVariable = Expression.Variable(typeof(ValueBuffer)); - var replacedReadExpression = ReplacingExpressionVisitor.Replace( - selector.Parameters[0], - valueBufferVariable, - readValueExpression); - - replacedReadExpression = replacedReadExpression.Type == methodCallExpression.Type - ? replacedReadExpression - : Expression.Convert(replacedReadExpression, methodCallExpression.Type); - - return Expression.Block( - variables: new[] { valueBufferVariable }, - Expression.Assign(valueBufferVariable, resultFunc), - Expression.Condition( - Expression.MakeMemberAccess(valueBufferVariable, _valueBufferIsEmpty), - Expression.Default(methodCallExpression.Type), - replacedReadExpression)); + return ProcessSingleResultScalar(subquery.ServerQueryExpression, + subquery.GetMappedProjection(projectionBindingExpression.ProjectionMember), + subquery.CurrentParameter, + methodCallExpression.Type); } if (methodCallExpression.Method == _likeMethodInfo @@ -636,48 +443,6 @@ MethodInfo GetMethod() return methodCallExpression.Update(@object, arguments); } - private static readonly MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0]; - - protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) - { - Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); - - if (typeBinaryExpression.NodeType == ExpressionType.TypeIs - && Visit(typeBinaryExpression.Expression) is EntityProjectionExpression entityProjectionExpression) - { - var entityType = entityProjectionExpression.EntityType; - - if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) - { - return Expression.Constant(true); - } - - var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); - if (derivedType != null) - { - var discriminatorProperty = entityType.GetDiscriminatorProperty(); - var boundProperty = BindProperty(entityProjectionExpression, discriminatorProperty); - - var equals = Expression.Equal( - boundProperty, - Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)); - - foreach (var derivedDerivedType in derivedType.GetDerivedTypes()) - { - equals = Expression.OrElse( - equals, - Expression.Equal( - boundProperty, - Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType))); - } - - return equals; - } - } - - return Expression.Constant(false); - } - protected override Expression VisitNew(NewExpression newExpression) { Check.NotNull(newExpression, nameof(newExpression)); @@ -726,89 +491,61 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio return newArrayExpression.Update(newExpressions); } - protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment) + protected override Expression VisitParameter(ParameterExpression parameterExpression) { - var expression = Visit(memberAssignment.Expression); - if (expression == null) - { - return null; - } + Check.NotNull(parameterExpression, nameof(parameterExpression)); - if (IsConvertedToNullable(expression, memberAssignment.Expression)) + if (parameterExpression.Name.StartsWith(_compiledQueryParameterPrefix, StringComparison.Ordinal)) { - expression = ConvertToNonNullable(expression); + return Expression.Call( + _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterExpression.Name)); } - return memberAssignment.Update(expression); + throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print())); } - protected override Expression VisitExtension(Expression extensionExpression) + protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) { - Check.NotNull(extensionExpression, nameof(extensionExpression)); + Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); - switch (extensionExpression) + if (typeBinaryExpression.NodeType == ExpressionType.TypeIs + && Visit(typeBinaryExpression.Expression) is EntityReferenceExpression entityReferenceExpression) { - case EntityProjectionExpression _: - return extensionExpression; + var entityType = entityReferenceExpression.EntityType; - case EntityShaperExpression entityShaperExpression: - return Visit(entityShaperExpression.ValueBufferExpression); - - case ProjectionBindingExpression projectionBindingExpression: - return projectionBindingExpression.ProjectionMember != null - ? ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression) - .GetMappedProjection(projectionBindingExpression.ProjectionMember) - : null; - - default: - return null; - } - } - - protected override Expression VisitListInit(ListInitExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } - - protected override Expression VisitInvocation(InvocationExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } + if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) + { + return Expression.Constant(true); + } - protected override Expression VisitLambda(Expression node) - { - Check.NotNull(node, nameof(node)); + var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); + if (derivedType != null) + { + var discriminatorProperty = entityType.GetDiscriminatorProperty(); + var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType); - return null; - } + var equals = Expression.Equal( + boundProperty, + Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)); - protected override Expression VisitParameter(ParameterExpression parameterExpression) - { - Check.NotNull(parameterExpression, nameof(parameterExpression)); + foreach (var derivedDerivedType in derivedType.GetDerivedTypes()) + { + equals = Expression.OrElse( + equals, + Expression.Equal( + boundProperty, + Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType))); + } - if (parameterExpression.Name.StartsWith(_compiledQueryParameterPrefix, StringComparison.Ordinal)) - { - return Expression.Call( - _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), - QueryCompilationContext.QueryContextParameter, - Expression.Constant(parameterExpression.Name)); + return equals; + } } - throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print())); + return Expression.Constant(false); } - private static readonly MethodInfo _getParameterValueMethodInfo - = typeof(InMemoryExpressionTranslatingExpressionVisitor) - .GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue)); - - [UsedImplicitly] - private static T GetParameterValue(QueryContext queryContext, string parameterName) - => (T)queryContext.ParameterValues[parameterName]; - protected override Expression VisitUnary(UnaryExpression unaryExpression) { Check.NotNull(unaryExpression, nameof(unaryExpression)); @@ -819,6 +556,14 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return null; } + if (newOperand is EntityReferenceExpression entityReferenceExpression + && (unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.ConvertChecked + || unaryExpression.NodeType == ExpressionType.TypeAs)) + { + return entityReferenceExpression.Convert(unaryExpression.Type); + } + if (unaryExpression.NodeType == ExpressionType.Convert && newOperand.Type == unaryExpression.Type) { @@ -856,9 +601,150 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return result; } + private Expression TryBindMember(Expression source, MemberIdentity member, Type type) + { + if (!(source is EntityReferenceExpression entityReferenceExpression)) + { + return null; + } + + var entityType = entityReferenceExpression.EntityType; + + var property = member.MemberInfo != null + ? entityType.FindProperty(member.MemberInfo) + : entityType.FindProperty(member.Name); + + return property != null ? BindProperty(entityReferenceExpression, property, type) : null; + } + + private Expression BindProperty(EntityReferenceExpression entityReferenceExpression, IProperty property, Type type) + { + if (entityReferenceExpression.ParameterEntity != null) + { + var result = ((EntityProjectionExpression)Visit(entityReferenceExpression.ParameterEntity.ValueBufferExpression)) + .BindProperty(property); + + // if the result type change was just nullability change e.g from int to int? + // we want to preserve the new type for null propagation + if (result.Type != type + && !(result.Type.IsNullableType() + && !type.IsNullableType() + && result.Type.UnwrapNullableType() == type)) + { + result = Expression.Convert(result, type); + } + + return result; + } + + if (entityReferenceExpression.SubqueryEntity != null) + { + var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; + var readValueExpression = ((EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression)).BindProperty(property); + var inMemoryQueryExpression = (InMemoryQueryExpression)entityReferenceExpression.SubqueryEntity.QueryExpression; + + return ProcessSingleResultScalar( + inMemoryQueryExpression.ServerQueryExpression, + readValueExpression, + inMemoryQueryExpression.CurrentParameter, + type); + } + + return null; + } + + private static Expression ProcessSingleResultScalar( + Expression serverQuery, Expression readValueExpression, Expression valueBufferParameter, Type type) + { + var singleResult = ((LambdaExpression)((NewExpression)serverQuery).Arguments[0]).Body; + if (readValueExpression is UnaryExpression unaryExpression + && unaryExpression.NodeType == ExpressionType.Convert + && unaryExpression.Type == typeof(object)) + { + readValueExpression = unaryExpression.Operand; + } + + var valueBufferVariable = Expression.Variable(typeof(ValueBuffer)); + var replacedReadExpression = ReplacingExpressionVisitor.Replace( + valueBufferParameter, + valueBufferVariable, + readValueExpression); + + replacedReadExpression = replacedReadExpression.Type == type + ? replacedReadExpression + : Expression.Convert(replacedReadExpression, type); + + return Expression.Block( + variables: new[] { valueBufferVariable }, + Expression.Assign(valueBufferVariable, singleResult), + Expression.Condition( + Expression.MakeMemberAccess(valueBufferVariable, _valueBufferIsEmpty), + Expression.Default(type), + replacedReadExpression)); + } + + [UsedImplicitly] + private static T GetParameterValue(QueryContext queryContext, string parameterName) + => (T)queryContext.ParameterValues[parameterName]; + + private static bool IsConvertedToNullable(Expression result, Expression original) + => result.Type.IsNullableType() + && !original.Type.IsNullableType() + && result.Type.UnwrapNullableType() == original.Type; + + private static Expression ConvertToNullable(Expression expression) + => !expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.MakeNullable()) + : expression; + + private static Expression ConvertToNonNullable(Expression expression) + => expression.Type.IsNullableType() + ? Expression.Convert(expression, expression.Type.UnwrapNullableType()) + : expression; + + private static Expression GetSelectorOnGrouping( + MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return groupByShaperExpression.ElementSelector; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + + private Expression GetPredicateOnGrouping( + MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return null; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + [DebuggerStepThrough] - private bool TranslationFailed(Expression original, Expression translation) - => original != null && (translation == null || translation is EntityProjectionExpression); + private static bool TranslationFailed(Expression original, Expression translation) + => original != null && (translation == null || translation is EntityReferenceExpression); private static bool InMemoryLike(string matchExpression, string pattern, string escapeCharacter) { @@ -940,5 +826,145 @@ var regexPattern RegexOptions.IgnoreCase | RegexOptions.Singleline, _regexTimeout); } + + private sealed class EntityReferenceFindingExpressionVisitor : ExpressionVisitor + { + private bool _found; + + public bool Find(Expression expression) + { + _found = false; + + Visit(expression); + + return _found; + } + + public override Expression Visit(Expression expression) + { + if (_found) + { + return expression; + } + + if (expression is EntityReferenceExpression) + { + _found = true; + return expression; + } + + return base.Visit(expression); + } + } + + private sealed class PropertyFindingExpressionVisitor : ExpressionVisitor + { + private readonly IModel _model; + private IProperty _property; + + public PropertyFindingExpressionVisitor(IModel model) + { + _model = model; + } + + public IProperty Find(Expression expression) + { + Visit(expression); + + return _property; + } + + protected override Expression VisitMember(MemberExpression memberExpression) + { + var entityType = FindEntityType(memberExpression.Expression); + if (entityType != null) + { + _property = GetProperty(entityType, MemberIdentity.Create(memberExpression.Member)); + } + + return memberExpression; + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName) + || methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) + { + var entityType = FindEntityType(source); + if (entityType != null) + { + _property = GetProperty(entityType, MemberIdentity.Create(propertyName)); + } + } + + return methodCallExpression; + } + + private static IProperty GetProperty(IEntityType entityType, MemberIdentity memberIdentity) + => memberIdentity.MemberInfo != null + ? entityType.FindProperty(memberIdentity.MemberInfo) + : entityType.FindProperty(memberIdentity.Name); + + private static IEntityType FindEntityType(Expression source) + { + source = source.UnwrapTypeConversion(out var convertedType); + + if (source is EntityShaperExpression entityShaperExpression) + { + var entityType = entityShaperExpression.EntityType; + if (convertedType != null) + { + entityType = entityType.GetRootType().GetDerivedTypesInclusive() + .FirstOrDefault(et => et.ClrType == convertedType); + } + + return entityType; + } + + return null; + } + } + + private sealed class EntityReferenceExpression : Expression + { + public EntityReferenceExpression(EntityShaperExpression parameter) + { + ParameterEntity = parameter; + EntityType = parameter.EntityType; + } + + public EntityReferenceExpression(ShapedQueryExpression subquery) + { + SubqueryEntity = subquery; + EntityType = ((EntityShaperExpression)subquery.ShaperExpression).EntityType; + } + + private EntityReferenceExpression(EntityReferenceExpression entityReferenceExpression, IEntityType entityType) + { + ParameterEntity = entityReferenceExpression.ParameterEntity; + SubqueryEntity = entityReferenceExpression.SubqueryEntity; + EntityType = entityType; + } + + public EntityShaperExpression ParameterEntity { get; } + public ShapedQueryExpression SubqueryEntity { get; } + public IEntityType EntityType { get; } + + public override Type Type => EntityType.ClrType; + public override ExpressionType NodeType => ExpressionType.Extension; + + public Expression Convert(Type type) + { + if (type == typeof(object) // Ignore object conversion + || type.IsAssignableFrom(Type)) // Ignore casting to base type/interface + { + return this; + } + + var derivedEntityType = EntityType.GetDerivedTypes().FirstOrDefault(et => et.ClrType == type); + + return derivedEntityType == null ? null : new EntityReferenceExpression(this, derivedEntityType); + } + } } } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 4aa6ef53eb3..c9057094e8b 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -19,11 +19,10 @@ namespace Microsoft.EntityFrameworkCore.Query public class RelationalSqlTranslatingExpressionVisitor : ExpressionVisitor { private readonly IModel _model; + private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlTypeMappingVerifyingExpressionVisitor; - protected virtual ISqlExpressionFactory SqlExpressionFactory { get; } - public RelationalSqlTranslatingExpressionVisitor( [NotNull] RelationalSqlTranslatingExpressionVisitorDependencies dependencies, [NotNull] QueryCompilationContext queryCompilationContext, @@ -34,7 +33,7 @@ public RelationalSqlTranslatingExpressionVisitor( Check.NotNull(queryableMethodTranslatingExpressionVisitor, nameof(queryableMethodTranslatingExpressionVisitor)); Dependencies = dependencies; - SqlExpressionFactory = dependencies.SqlExpressionFactory; + _sqlExpressionFactory = dependencies.SqlExpressionFactory; _model = queryCompilationContext.Model; _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; @@ -58,7 +57,7 @@ public virtual SqlExpression Translate([NotNull] Expression expression) translation = sqlUnaryExpression.Operand; } - translation = SqlExpressionFactory.ApplyDefaultTypeMapping(translation); + translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(translation); if (translation.TypeMapping == null) { @@ -92,13 +91,13 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression) if (inputType == typeof(int) || inputType == typeof(long)) { - sqlExpression = SqlExpressionFactory.ApplyDefaultTypeMapping( - SqlExpressionFactory.Convert(sqlExpression, typeof(double))); + sqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Convert(sqlExpression, typeof(double))); } return inputType == typeof(float) - ? SqlExpressionFactory.Convert( - SqlExpressionFactory.Function( + ? _sqlExpressionFactory.Convert( + _sqlExpressionFactory.Function( "AVG", new[] { sqlExpression }, nullable: true, @@ -106,7 +105,7 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression) typeof(double)), sqlExpression.Type, sqlExpression.TypeMapping) - : (SqlExpression)SqlExpressionFactory.Function( + : (SqlExpression)_sqlExpressionFactory.Function( "AVG", new[] { sqlExpression }, nullable: true, @@ -123,10 +122,10 @@ public virtual SqlExpression TranslateCount([CanBeNull] Expression expression = return null; } - return SqlExpressionFactory.ApplyDefaultTypeMapping( - SqlExpressionFactory.Function( + return _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Function( "COUNT", - new[] { SqlExpressionFactory.Fragment("*") }, + new[] { _sqlExpressionFactory.Fragment("*") }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(int))); @@ -140,10 +139,10 @@ public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expressio return null; } - return SqlExpressionFactory.ApplyDefaultTypeMapping( - SqlExpressionFactory.Function( + return _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Function( "COUNT", - new[] { SqlExpressionFactory.Fragment("*") }, + new[] { _sqlExpressionFactory.Fragment("*") }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(long))); @@ -159,7 +158,7 @@ public virtual SqlExpression TranslateMax([NotNull] Expression expression) } return sqlExpression != null - ? SqlExpressionFactory.Function( + ? _sqlExpressionFactory.Function( "MAX", new[] { sqlExpression }, nullable: true, @@ -179,7 +178,7 @@ public virtual SqlExpression TranslateMin([NotNull] Expression expression) } return sqlExpression != null - ? SqlExpressionFactory.Function( + ? _sqlExpressionFactory.Function( "MIN", new[] { sqlExpression }, nullable: true, @@ -206,8 +205,8 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression) var inputType = sqlExpression.Type.UnwrapNullableType(); return inputType == typeof(float) - ? SqlExpressionFactory.Convert( - SqlExpressionFactory.Function( + ? _sqlExpressionFactory.Convert( + _sqlExpressionFactory.Function( "SUM", new[] { sqlExpression }, nullable: true, @@ -215,7 +214,7 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression) typeof(double)), inputType, sqlExpression.TypeMapping) - : (SqlExpression)SqlExpressionFactory.Function( + : (SqlExpression)_sqlExpressionFactory.Function( "SUM", new[] { sqlExpression }, nullable: true, @@ -224,144 +223,100 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression) sqlExpression.TypeMapping); } - private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor + protected override Expression VisitBinary(BinaryExpression binaryExpression) { - protected override Expression VisitExtension(Expression node) - { - Check.NotNull(node, nameof(node)); - - if (node is SqlExpression sqlExpression - && !(node is SqlFragmentExpression) - && !(node is SqlFunctionExpression sqlFunctionExpression - && sqlFunctionExpression.Type.IsQueryableType())) - { - if (sqlExpression.TypeMapping == null) - { - throw new InvalidOperationException(CoreStrings.NullTypeMappingInSqlTree); - } - } + Check.NotNull(binaryExpression, nameof(binaryExpression)); - return base.VisitExtension(node); + if (binaryExpression.Left.Type == typeof(AnonymousObject) + && binaryExpression.NodeType == ExpressionType.Equal) + { + return Visit(ConvertAnonymousObjectEqualityComparison(binaryExpression)); } - } - protected override Expression VisitMember(MemberExpression memberExpression) - { - Check.NotNull(memberExpression, nameof(memberExpression)); + var uncheckedNodeTypeVariant = binaryExpression.NodeType switch + { + ExpressionType.AddChecked => ExpressionType.Add, + ExpressionType.SubtractChecked => ExpressionType.Subtract, + ExpressionType.MultiplyChecked => ExpressionType.Multiply, + _ => binaryExpression.NodeType + }; - return TryBindMember(memberExpression.Expression, MemberIdentity.Create(memberExpression.Member), out var result) - ? result - : TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression) - ? null - : Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type); + var left = TryRemoveImplicitConvert(binaryExpression.Left); + var right = TryRemoveImplicitConvert(binaryExpression.Right); + + return TranslationFailed(binaryExpression.Left, Visit(left), out var sqlLeft) + || TranslationFailed(binaryExpression.Right, Visit(right), out var sqlRight) + ? null + : uncheckedNodeTypeVariant == ExpressionType.Coalesce + ? _sqlExpressionFactory.Coalesce(sqlLeft, sqlRight) + : (Expression)_sqlExpressionFactory.MakeBinary( + uncheckedNodeTypeVariant, + sqlLeft, + sqlRight, + null); } - private bool TryBindMember(Expression source, MemberIdentity member, out Expression expression) + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) { - source = source.UnwrapTypeConversion(out var convertedType); - expression = null; - if (source is EntityShaperExpression entityShaperExpression) - { - var entityType = entityShaperExpression.EntityType; - if (convertedType != null) - { - entityType = entityType.GetRootType().GetDerivedTypesInclusive() - .FirstOrDefault(et => et.ClrType == convertedType); - if (entityType == null) - { - return false; - } - } + Check.NotNull(conditionalExpression, nameof(conditionalExpression)); - var property = member.MemberInfo != null - ? entityType.FindProperty(member.MemberInfo) - : entityType.FindProperty(member.Name); - if (property != null - && Visit(entityShaperExpression.ValueBufferExpression) is EntityProjectionExpression entityProjectionExpression - && (entityProjectionExpression.EntityType.IsAssignableFrom(property.DeclaringEntityType) - || property.DeclaringEntityType.IsAssignableFrom(entityProjectionExpression.EntityType))) - { - expression = entityProjectionExpression.BindProperty(property); - return true; - } - } + var test = Visit(conditionalExpression.Test); + var ifTrue = Visit(conditionalExpression.IfTrue); + var ifFalse = Visit(conditionalExpression.IfFalse); - return false; + return TranslationFailed(conditionalExpression.Test, test, out var sqlTest) + || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) + || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse) + ? null + : _sqlExpressionFactory.Case(new[] { new CaseWhenClause(sqlTest, sqlIfTrue) }, sqlIfFalse); } - protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) + protected override Expression VisitConstant(ConstantExpression constantExpression) + => new SqlConstantExpression(Check.NotNull(constantExpression, nameof(constantExpression)), null); + + protected override Expression VisitExtension(Expression extensionExpression) { - Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); + Check.NotNull(extensionExpression, nameof(extensionExpression)); - if (typeBinaryExpression.NodeType == ExpressionType.TypeIs - && Visit(typeBinaryExpression.Expression) is EntityProjectionExpression entityProjectionExpression) + switch (extensionExpression) { - var entityType = entityProjectionExpression.EntityType; - if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) - { - return SqlExpressionFactory.Constant(true); - } + case EntityProjectionExpression _: + case SqlExpression _: + return extensionExpression; - var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); - if (derivedType != null) - { - var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList(); - var discriminatorColumn = entityProjectionExpression.BindProperty(entityType.GetDiscriminatorProperty()); + case EntityShaperExpression entityShaperExpression: + return new EntityReferenceExpression(entityShaperExpression); - return concreteEntityTypes.Count == 1 - ? SqlExpressionFactory.Equal( - discriminatorColumn, - SqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) - : (Expression)SqlExpressionFactory.In( - discriminatorColumn, - SqlExpressionFactory.Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), - negated: false); - } + case ProjectionBindingExpression projectionBindingExpression: + return projectionBindingExpression.ProjectionMember != null + ? ((SelectExpression)projectionBindingExpression.QueryExpression) + .GetMappedProjection(projectionBindingExpression.ProjectionMember) + : null; - return SqlExpressionFactory.Constant(false); + default: + return null; } - - return null; } - private Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) - { - if (methodCallExpression.Arguments.Count == 1) - { - return groupByShaperExpression.ElementSelector; - } + protected override Expression VisitInvocation(InvocationExpression invocationExpression) => null; + protected override Expression VisitLambda(Expression lambdaExpression) => null; + protected override Expression VisitListInit(ListInitExpression listInitExpression) => null; - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } - - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); - } - - private Expression GetPredicate(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + protected override Expression VisitMember(MemberExpression memberExpression) { - if (methodCallExpression.Arguments.Count == 1) - { - return null; - } + Check.NotNull(memberExpression, nameof(memberExpression)); - if (methodCallExpression.Arguments.Count == 2) - { - var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - return ReplacingExpressionVisitor.Replace( - selectorLambda.Parameters[0], - groupByShaperExpression.ElementSelector, - selectorLambda.Body); - } + var innerExpression = Visit(memberExpression.Expression); - throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + return TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member)) + ?? (TranslationFailed(memberExpression.Expression, base.Visit(memberExpression.Expression), out var sqlInnerExpression) + ? null + : Dependencies.MemberTranslatorProvider.Translate(sqlInnerExpression, memberExpression.Member, memberExpression.Type)); } + protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) + => GetConstantOrNull(Check.NotNull(memberInitExpression, nameof(memberInitExpression))); + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { Check.NotNull(methodCallExpression, nameof(methodCallExpression)); @@ -369,18 +324,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF.Property case if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName)) { - if (TryBindMember(source, MemberIdentity.Create(propertyName), out var result)) - { - return result; - } - - throw new InvalidOperationException(CoreStrings.EFPropertyCalledWithWrongPropertyName); + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName)) + ?? throw new InvalidOperationException(CoreStrings.QueryUnableToTranslateEFProperty(methodCallExpression.Print())); } // EF Indexer property if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) { - return TryBindMember(source, MemberIdentity.Create(propertyName), out var result) ? result : null; + return TryBindMember(Visit(source), MemberIdentity.Create(propertyName)); } // GroupBy Aggregate case @@ -391,12 +342,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp { var translatedAggregate = methodCallExpression.Method.Name switch { - nameof(Enumerable.Average) => TranslateAverage(GetSelector(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Count) => TranslateCount(GetPredicate(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.LongCount) => TranslateLongCount(GetPredicate(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Max) => TranslateMax(GetSelector(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Min) => TranslateMin(GetSelector(methodCallExpression, groupByShaperExpression)), - nameof(Enumerable.Sum) => TranslateSum(GetSelector(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Average) => TranslateAverage(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Count) => TranslateCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.LongCount) => TranslateLongCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Max) => TranslateMax(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Min) => TranslateMin(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), + nameof(Enumerable.Sum) => TranslateSum(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)), _ => null }; @@ -434,8 +385,10 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) return null; } - var subquery = (SelectExpression)subqueryTranslation.QueryExpression; - subquery.ApplyProjection(); + if (subqueryTranslation.ShaperExpression is EntityShaperExpression entityShaperExpression) + { + return new EntityReferenceExpression(subqueryTranslation); + } if (!(subqueryTranslation.ShaperExpression is ProjectionBindingExpression || IsAggregateResultWithCustomShaper(methodCallExpression.Method))) @@ -443,6 +396,9 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) return null; } + var subquery = (SelectExpression)subqueryTranslation.QueryExpression; + subquery.ApplyProjection(); + #pragma warning disable IDE0046 // Convert to conditional expression if (subquery.Tables.Count == 0 #pragma warning restore IDE0046 // Convert to conditional expression @@ -479,6 +435,183 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); } + protected override Expression VisitNew(NewExpression newExpression) + => GetConstantOrNull(Check.NotNull(newExpression, nameof(newExpression))); + + protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) => null; + + protected override Expression VisitParameter(ParameterExpression parameterExpression) + => new SqlParameterExpression(Check.NotNull(parameterExpression, nameof(parameterExpression)), null); + + protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression) + { + Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression)); + + var innerExpression = Visit(typeBinaryExpression.Expression); + + if (typeBinaryExpression.NodeType == ExpressionType.TypeIs + && innerExpression is EntityReferenceExpression entityReferenceExpression) + { + var entityType = entityReferenceExpression.EntityType; + if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) + { + return _sqlExpressionFactory.Constant(true); + } + + var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); + if (derivedType != null) + { + var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList(); + var discriminatorColumn = BindProperty(entityReferenceExpression, entityType.GetDiscriminatorProperty()); + + return concreteEntityTypes.Count == 1 + ? _sqlExpressionFactory.Equal( + discriminatorColumn, + _sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) + : (Expression)_sqlExpressionFactory.In( + discriminatorColumn, + _sqlExpressionFactory.Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), + negated: false); + } + + return _sqlExpressionFactory.Constant(false); + } + + return null; + } + + protected override Expression VisitUnary(UnaryExpression unaryExpression) + { + Check.NotNull(unaryExpression, nameof(unaryExpression)); + + var operand = Visit(unaryExpression.Operand); + + if (operand is EntityReferenceExpression entityReferenceExpression + && (unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.ConvertChecked + || unaryExpression.NodeType == ExpressionType.TypeAs)) + { + return entityReferenceExpression.Convert(unaryExpression.Type); + } + + if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) + { + return null; + } + + switch (unaryExpression.NodeType) + { + case ExpressionType.Not: + return _sqlExpressionFactory.Not(sqlOperand); + + case ExpressionType.Negate: + return _sqlExpressionFactory.Negate(sqlOperand); + + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.TypeAs: + // Object convert needs to be converted to explicit cast when mismatching types + if (operand.Type.IsInterface + && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) + || unaryExpression.Type.UnwrapNullableType() == operand.Type.UnwrapNullableType() + || unaryExpression.Type.UnwrapNullableType() == typeof(Enum)) + { + return sqlOperand; + } + + // Introduce explicit cast only if the target type is mapped else we need to client eval + if (unaryExpression.Type == typeof(object) + || _sqlExpressionFactory.FindMapping(unaryExpression.Type) != null) + { + sqlOperand = _sqlExpressionFactory.ApplyDefaultTypeMapping(sqlOperand); + + return _sqlExpressionFactory.Convert(sqlOperand, unaryExpression.Type); + } + + break; + + case ExpressionType.Quote: + return operand; + } + + return null; + } + + private Expression TryBindMember(Expression source, MemberIdentity member) + { + if (!(source is EntityReferenceExpression entityReferenceExpression)) + { + return null; + } + + var entityType = entityReferenceExpression.EntityType; + var property = member.MemberInfo != null + ? entityType.FindProperty(member.MemberInfo) + : entityType.FindProperty(member.Name); + + return property != null ? BindProperty(entityReferenceExpression, property) : null; + } + + private SqlExpression BindProperty(EntityReferenceExpression entityReferenceExpression, IProperty property) + { + if (entityReferenceExpression.ParameterEntity != null) + { + return ((EntityProjectionExpression)Visit(entityReferenceExpression.ParameterEntity.ValueBufferExpression)).BindProperty(property); + } + + if (entityReferenceExpression.SubqueryEntity != null) + { + var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; + var innerProjection = ((EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression)).BindProperty(property); + var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression; + subSelectExpression.AddToProjection(innerProjection); + + return new ScalarSubqueryExpression(subSelectExpression); + } + + return null; + } + + private static Expression GetSelectorOnGrouping( + MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return groupByShaperExpression.ElementSelector; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + + private static Expression GetPredicateOnGrouping( + MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression) + { + if (methodCallExpression.Arguments.Count == 1) + { + return null; + } + + if (methodCallExpression.Arguments.Count == 2) + { + var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + return ReplacingExpressionVisitor.Replace( + selectorLambda.Parameters[0], + groupByShaperExpression.ElementSelector, + selectorLambda.Body); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print())); + } + private static Expression TryRemoveImplicitConvert(Expression expression) { if (expression is UnaryExpression unaryExpression) @@ -510,7 +643,7 @@ private static Expression TryRemoveImplicitConvert(Expression expression) return expression; } - private Expression ConvertAnonymousObjectEqualityComparison(BinaryExpression binaryExpression) + private static Expression ConvertAnonymousObjectEqualityComparison(BinaryExpression binaryExpression) { var leftExpressions = ((NewArrayExpression)((NewExpression)binaryExpression.Left).Arguments[0]).Expressions; var rightExpressions = ((NewArrayExpression)((NewExpression)binaryExpression.Right).Arguments[0]).Expressions; @@ -542,49 +675,14 @@ static Expression RemoveObjectConvert(Expression expression) : expression; } - protected override Expression VisitBinary(BinaryExpression binaryExpression) - { - Check.NotNull(binaryExpression, nameof(binaryExpression)); - - if (binaryExpression.Left.Type == typeof(AnonymousObject) - && binaryExpression.NodeType == ExpressionType.Equal) - { - return Visit(ConvertAnonymousObjectEqualityComparison(binaryExpression)); - } - - var uncheckedNodeTypeVariant = binaryExpression.NodeType switch - { - ExpressionType.AddChecked => ExpressionType.Add, - ExpressionType.SubtractChecked => ExpressionType.Subtract, - ExpressionType.MultiplyChecked => ExpressionType.Multiply, - _ => binaryExpression.NodeType - }; - - var left = TryRemoveImplicitConvert(binaryExpression.Left); - var right = TryRemoveImplicitConvert(binaryExpression.Right); - - return TranslationFailed(binaryExpression.Left, Visit(left), out var sqlLeft) - || TranslationFailed(binaryExpression.Right, Visit(right), out var sqlRight) - ? null - : uncheckedNodeTypeVariant == ExpressionType.Coalesce - ? SqlExpressionFactory.Coalesce(sqlLeft, sqlRight) - : (Expression)SqlExpressionFactory.MakeBinary( - uncheckedNodeTypeVariant, - sqlLeft, - sqlRight, - null); - } - - private SqlConstantExpression GetConstantOrNull(Expression expression) - { - if (CanEvaluate(expression)) - { - var value = Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(); - return new SqlConstantExpression(Expression.Constant(value, expression.Type), null); - } - - return null; - } + private static SqlConstantExpression GetConstantOrNull(Expression expression) + => CanEvaluate(expression) + ? new SqlConstantExpression( + Expression.Constant( + Expression.Lambda>(Expression.Convert(expression, typeof(object))).Compile().Invoke(), + expression.Type), + null) + : null; private static bool CanEvaluate(Expression expression) { @@ -608,162 +706,81 @@ private static bool CanEvaluate(Expression expression) } } - protected override Expression VisitNew(NewExpression node) - { - Check.NotNull(node, nameof(node)); - - return GetConstantOrNull(node); - } - - protected override Expression VisitMemberInit(MemberInitExpression node) - { - Check.NotNull(node, nameof(node)); - - return GetConstantOrNull(node); - } - - protected override Expression VisitNewArray(NewArrayExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } - - protected override Expression VisitListInit(ListInitExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } - - protected override Expression VisitInvocation(InvocationExpression node) - { - Check.NotNull(node, nameof(node)); - - return null; - } - - protected override Expression VisitLambda(Expression node) - { - Check.NotNull(node, nameof(node)); - - return node.Body != null ? Visit(node.Body) : null; - } - - protected override Expression VisitConstant(ConstantExpression constantExpression) - { - Check.NotNull(constantExpression, nameof(constantExpression)); - - return new SqlConstantExpression(constantExpression, null); - } - - protected override Expression VisitParameter(ParameterExpression parameterExpression) - { - Check.NotNull(parameterExpression, nameof(parameterExpression)); - - return new SqlParameterExpression(parameterExpression, null); - } - - protected override Expression VisitExtension(Expression extensionExpression) + [DebuggerStepThrough] + private static bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation) { - Check.NotNull(extensionExpression, nameof(extensionExpression)); - - switch (extensionExpression) + if (original != null + && !(translation is SqlExpression)) { - case EntityProjectionExpression _: - case SqlExpression _: - return extensionExpression; - - case EntityShaperExpression entityShaperExpression: - return Visit(entityShaperExpression.ValueBufferExpression); - - case ProjectionBindingExpression projectionBindingExpression: - return projectionBindingExpression.ProjectionMember != null - ? ((SelectExpression)projectionBindingExpression.QueryExpression) - .GetMappedProjection(projectionBindingExpression.ProjectionMember) - : null; - - default: - return null; + castTranslation = null; + return true; } - } - protected override Expression VisitConditional(ConditionalExpression conditionalExpression) - { - Check.NotNull(conditionalExpression, nameof(conditionalExpression)); - - var test = Visit(conditionalExpression.Test); - var ifTrue = Visit(conditionalExpression.IfTrue); - var ifFalse = Visit(conditionalExpression.IfFalse); - - return TranslationFailed(conditionalExpression.Test, test, out var sqlTest) - || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) - || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse) - ? null - : SqlExpressionFactory.Case(new[] { new CaseWhenClause(sqlTest, sqlIfTrue) }, sqlIfFalse); + castTranslation = translation as SqlExpression; + return false; } - protected override Expression VisitUnary(UnaryExpression unaryExpression) + private sealed class EntityReferenceExpression : Expression { - Check.NotNull(unaryExpression, nameof(unaryExpression)); - - var operand = Visit(unaryExpression.Operand); + public EntityReferenceExpression(EntityShaperExpression parameter) + { + ParameterEntity = parameter; + EntityType = parameter.EntityType; + } - if (TranslationFailed(unaryExpression.Operand, operand, out var sqlOperand)) + public EntityReferenceExpression(ShapedQueryExpression subquery) { - return null; + SubqueryEntity = subquery; + EntityType = ((EntityShaperExpression)subquery.ShaperExpression).EntityType; } - switch (unaryExpression.NodeType) + private EntityReferenceExpression(EntityReferenceExpression entityReferenceExpression, IEntityType entityType) { - case ExpressionType.Not: - return SqlExpressionFactory.Not(sqlOperand); + ParameterEntity = entityReferenceExpression.ParameterEntity; + SubqueryEntity = entityReferenceExpression.SubqueryEntity; + EntityType = entityType; + } - case ExpressionType.Negate: - return SqlExpressionFactory.Negate(sqlOperand); + public EntityShaperExpression ParameterEntity { get; } + public ShapedQueryExpression SubqueryEntity { get; } + public IEntityType EntityType { get; } - case ExpressionType.Convert: - case ExpressionType.ConvertChecked: - case ExpressionType.TypeAs: - // Object convert needs to be converted to explicit cast when mismatching types - if (operand.Type.IsInterface - && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) - || unaryExpression.Type.UnwrapNullableType() == operand.Type.UnwrapNullableType() - || unaryExpression.Type.UnwrapNullableType() == typeof(Enum)) - { - return sqlOperand; - } - - // Introduce explicit cast only if the target type is mapped else we need to client eval - if (unaryExpression.Type == typeof(object) - || SqlExpressionFactory.FindMapping(unaryExpression.Type) != null) - { - sqlOperand = SqlExpressionFactory.ApplyDefaultTypeMapping(sqlOperand); + public override Type Type => EntityType.ClrType; + public override ExpressionType NodeType => ExpressionType.Extension; - return SqlExpressionFactory.Convert(sqlOperand, unaryExpression.Type); - } + public Expression Convert(Type type) + { + if (type == typeof(object) // Ignore object conversion + || type.IsAssignableFrom(Type)) // Ignore casting to base type/interface + { + return this; + } - break; + var derivedEntityType = EntityType.GetDerivedTypes().FirstOrDefault(et => et.ClrType == type); - case ExpressionType.Quote: - return operand; + return derivedEntityType == null ? null : new EntityReferenceExpression(this, derivedEntityType); } - - return null; } - [DebuggerStepThrough] - private bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation) + private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor { - if (original != null - && !(translation is SqlExpression)) + protected override Expression VisitExtension(Expression extensionExpression) { - castTranslation = null; - return true; - } + Check.NotNull(extensionExpression, nameof(extensionExpression)); - castTranslation = translation as SqlExpression; - return false; + if (extensionExpression is SqlExpression sqlExpression + && !(extensionExpression is SqlFragmentExpression) + && !(extensionExpression is SqlFunctionExpression sqlFunctionExpression + && sqlFunctionExpression.Type.IsQueryableType())) + { + if (sqlExpression.TypeMapping == null) + { + throw new InvalidOperationException(CoreStrings.NullTypeMappingInSqlTree); + } + } + + return base.VisitExtension(extensionExpression); + } } } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs index 36a4cc663dc..8b3530ffdc7 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs @@ -45,19 +45,14 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) { Check.NotNull(binaryExpression, nameof(binaryExpression)); - var visitedExpression = (SqlExpression)base.VisitBinary(binaryExpression); - - if (visitedExpression == null) - { - return null; - } - - return visitedExpression is SqlBinaryExpression sqlBinary - && _arithmeticOperatorTypes.Contains(sqlBinary.OperatorType) - && (_dateTimeDataTypes.Contains(GetProviderType(sqlBinary.Left)) - || _dateTimeDataTypes.Contains(GetProviderType(sqlBinary.Right))) - ? null - : visitedExpression; + return !(base.VisitBinary(binaryExpression) is SqlExpression visitedExpression) + ? (Expression)null + : (Expression)(visitedExpression is SqlBinaryExpression sqlBinary + && _arithmeticOperatorTypes.Contains(sqlBinary.OperatorType) + && (_dateTimeDataTypes.Contains(GetProviderType(sqlBinary.Left)) + || _dateTimeDataTypes.Contains(GetProviderType(sqlBinary.Right))) + ? null + : visitedExpression); } protected override Expression VisitUnary(UnaryExpression unaryExpression) @@ -65,15 +60,13 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) if (unaryExpression.NodeType == ExpressionType.ArrayLength && unaryExpression.Operand.Type == typeof(byte[])) { - var sqlExpression = base.Visit(unaryExpression.Operand) as SqlExpression; - - if (sqlExpression == null) + if (!(base.Visit(unaryExpression.Operand) is SqlExpression sqlExpression)) { return null; } var isBinaryMaxDataType = GetProviderType(sqlExpression) == "varbinary(max)" || sqlExpression is SqlParameterExpression; - var dataLengthSqlFunction = SqlExpressionFactory.Function( + var dataLengthSqlFunction = Dependencies.SqlExpressionFactory.Function( "DATALENGTH", new[] { sqlExpression }, nullable: true, @@ -81,7 +74,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) isBinaryMaxDataType ? typeof(long) : typeof(int)); return isBinaryMaxDataType - ? (Expression)SqlExpressionFactory.Convert(dataLengthSqlFunction, typeof(int)) + ? (Expression)Dependencies.SqlExpressionFactory.Convert(dataLengthSqlFunction, typeof(int)) : dataLengthSqlFunction; } @@ -96,18 +89,15 @@ public override SqlExpression TranslateLongCount(Expression expression = null) return null; } - return SqlExpressionFactory.ApplyDefaultTypeMapping( - SqlExpressionFactory.Function( + return Dependencies.SqlExpressionFactory.ApplyDefaultTypeMapping( + Dependencies.SqlExpressionFactory.Function( "COUNT_BIG", - new[] { SqlExpressionFactory.Fragment("*") }, + new[] { Dependencies.SqlExpressionFactory.Fragment("*") }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(long))); } - private static string GetProviderType(SqlExpression expression) - { - return expression.TypeMapping?.StoreType; - } + private static string GetProviderType(SqlExpression expression) => expression.TypeMapping?.StoreType; } } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index 56f8fdf6d4e..06a24c1c003 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -94,7 +94,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) && unaryExpression.Operand.Type == typeof(byte[])) { return base.Visit(unaryExpression.Operand) is SqlExpression sqlExpression - ? SqlExpressionFactory.Function( + ? Dependencies.SqlExpressionFactory.Function( "length", new[] { sqlExpression }, nullable: true, @@ -127,9 +127,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) { Check.NotNull(binaryExpression, nameof(binaryExpression)); - var visitedExpression = (SqlExpression)base.VisitBinary(binaryExpression); - - if (visitedExpression == null) + if (!(base.VisitBinary(binaryExpression) is SqlExpression visitedExpression)) { return null; } @@ -140,7 +138,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) && (_functionModuloTypes.Contains(GetProviderType(sqlBinary.Left)) || _functionModuloTypes.Contains(GetProviderType(sqlBinary.Right)))) { - return SqlExpressionFactory.Function( + return Dependencies.SqlExpressionFactory.Function( "ef_mod", new[] { sqlBinary.Left, sqlBinary.Right }, nullable: true, @@ -242,20 +240,20 @@ private static bool AttemptDecimalCompare(SqlBinaryExpression sqlBinary) => private Expression DoDecimalCompare(SqlExpression visitedExpression, ExpressionType op, SqlExpression left, SqlExpression right) { - var actual = SqlExpressionFactory.Function( + var actual = Dependencies.SqlExpressionFactory.Function( name: "ef_compare", new[] { left, right }, nullable: true, new[] { true, true }, typeof(int)); - var oracle = SqlExpressionFactory.Constant(value: 0); + var oracle = Dependencies.SqlExpressionFactory.Constant(value: 0); return op switch { - ExpressionType.GreaterThan => SqlExpressionFactory.GreaterThan(left: actual, right: oracle), - ExpressionType.GreaterThanOrEqual => SqlExpressionFactory.GreaterThanOrEqual(left: actual, right: oracle), - ExpressionType.LessThan => SqlExpressionFactory.LessThan(left: actual, right: oracle), - ExpressionType.LessThanOrEqual => SqlExpressionFactory.LessThanOrEqual(left: actual, right: oracle), + ExpressionType.GreaterThan => Dependencies.SqlExpressionFactory.GreaterThan(left: actual, right: oracle), + ExpressionType.GreaterThanOrEqual => Dependencies.SqlExpressionFactory.GreaterThanOrEqual(left: actual, right: oracle), + ExpressionType.LessThan => Dependencies.SqlExpressionFactory.LessThan(left: actual, right: oracle), + ExpressionType.LessThanOrEqual => Dependencies.SqlExpressionFactory.LessThanOrEqual(left: actual, right: oracle), _ => visitedExpression }; } diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index 465ea5761da..dab2e1a0d33 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -2153,7 +2153,7 @@ public static string PropertyWrongName([CanBeNull] object property, [CanBeNull] property, entityType, clrName); /// - /// The indexed property '{property}' cannot be added to type '{entityType}' because the CLR class contains a member with the same name. + /// The indexer property '{property}' cannot be added to type '{entityType}' because the CLR class contains a member with the same name. /// public static string PropertyClashingNonIndexer([CanBeNull] object property, [CanBeNull] object entityType) => string.Format( @@ -2535,10 +2535,12 @@ public static string UnsupportedBinaryOperator => GetString("UnsupportedBinaryOperator"); /// - /// EF.Property called with wrong property name. + /// Translation of '{expression}' failed. Either source is not an entity type or the specified property does not exist on the entity type. /// - public static string EFPropertyCalledWithWrongPropertyName - => GetString("EFPropertyCalledWithWrongPropertyName"); + public static string QueryUnableToTranslateEFProperty([CanBeNull] object expression) + => string.Format( + GetString("QueryUnableToTranslateEFProperty", nameof(expression)), + expression); /// /// Invalid {state} encountered. diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index 6d70da84fd6..a66b5a6e5f4 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -1341,8 +1341,8 @@ Unsupported Binary operator type specified. - - EF.Property called with wrong property name. + + Translation of '{expression}' failed. Either source is not an entity type or the specified property does not exist on the entity type. Invalid {state} encountered. diff --git a/src/EFCore/Query/EntityShaperExpression.cs b/src/EFCore/Query/EntityShaperExpression.cs index 302786bc85b..a594c307e96 100644 --- a/src/EFCore/Query/EntityShaperExpression.cs +++ b/src/EFCore/Query/EntityShaperExpression.cs @@ -10,7 +10,6 @@ using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities;