diff --git a/src/EFCore.InMemory/Query/PipeLine/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/PipeLine/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 896bdaaa580..94bf36ffa09 100644 --- a/src/EFCore.InMemory/Query/PipeLine/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/PipeLine/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -59,7 +59,7 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour return source; } - protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector) + protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector, Type resultType) => TranslateScalarAggregate(source, selector, nameof(Enumerable.Average)); protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType) => throw new NotImplementedException(); @@ -192,10 +192,10 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio return source; } - protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector) + protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector, Type resultType) => TranslateScalarAggregate(source, selector, nameof(Enumerable.Max)); - protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector) + protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType) => TranslateScalarAggregate(source, selector, nameof(Enumerable.Min)); protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType) => throw new NotImplementedException(); @@ -269,7 +269,7 @@ protected override ShapedQueryExpression TranslateSkip(ShapedQueryExpression sou protected override ShapedQueryExpression TranslateSkipWhile(ShapedQueryExpression source, LambdaExpression predicate) => throw new NotImplementedException(); - protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector) + protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector, Type resultType) => TranslateScalarAggregate(source, selector, nameof(Enumerable.Sum)); protected override ShapedQueryExpression TranslateTake(ShapedQueryExpression source, Expression count) diff --git a/src/EFCore.Relational/Query/PipeLine/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/PipeLine/QuerySqlGenerator.cs index 3c654454896..1ea37e17be3 100644 --- a/src/EFCore.Relational/Query/PipeLine/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/PipeLine/QuerySqlGenerator.cs @@ -188,11 +188,20 @@ protected override Expression VisitTable(TableExpression tableExpression) protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpression) { - Visit(sqlBinaryExpression.Left); - - _relationalCommandBuilder.Append(_operatorMap[sqlBinaryExpression.OperatorType]); - - Visit(sqlBinaryExpression.Right); + if (sqlBinaryExpression.OperatorType == ExpressionType.Coalesce) + { + _relationalCommandBuilder.Append("COALESCE("); + Visit(sqlBinaryExpression.Left); + _relationalCommandBuilder.Append(", "); + Visit(sqlBinaryExpression.Right); + _relationalCommandBuilder.Append(")"); + } + else + { + Visit(sqlBinaryExpression.Left); + _relationalCommandBuilder.Append(_operatorMap[sqlBinaryExpression.OperatorType]); + Visit(sqlBinaryExpression.Right); + } return sqlBinaryExpression; } diff --git a/src/EFCore.Relational/Query/PipeLine/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/PipeLine/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 72a4e39e762..bf240ad8064 100644 --- a/src/EFCore.Relational/Query/PipeLine/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/PipeLine/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Query.PipeLine; using Microsoft.EntityFrameworkCore.Relational.Query.PipeLine.SqlExpressions; using Microsoft.EntityFrameworkCore.Storage; @@ -79,7 +80,57 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour return source; } - protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector) => throw new NotImplementedException(); + protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector, Type resultType) + { + if (selector != null) + { + source = TranslateSelect(source, selector); + } + + var projection = (SqlExpression)((SelectExpression)source.QueryExpression) + .GetProjectionExpression(new ProjectionMember()); + + var inputType = projection.Type.UnwrapNullableType(); + if (inputType == typeof(int) + || inputType == typeof(long)) + { + projection = new SqlCastExpression(projection, typeof(double), _typeMappingSource.FindMapping(typeof(double))); + } + + if (projection.Type.UnwrapNullableType() == typeof(float)) + { + projection = new SqlCastExpression( + new SqlFunctionExpression( + null, + "AVG", + null, + new[] + { + projection + }, + typeof(double), + _typeMappingSource.FindMapping(typeof(double)), + false), + projection.Type, + projection.TypeMapping); + } + else + { + projection = new SqlFunctionExpression( + null, + "AVG", + null, + new[] + { + projection + }, + projection.Type, + projection.TypeMapping, + false); + } + + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + } protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType) => throw new NotImplementedException(); @@ -183,9 +234,55 @@ protected override ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpre protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpression source, LambdaExpression predicate) => throw new NotImplementedException(); - protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector) => throw new NotImplementedException(); + protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector, Type resultType) + { + if (selector != null) + { + source = TranslateSelect(source, selector); + } + + var projection = (SqlExpression)((SelectExpression)source.QueryExpression) + .GetProjectionExpression(new ProjectionMember()); + + projection = new SqlFunctionExpression( + null, + "MAX", + null, + new[] + { + projection + }, + resultType, + projection.TypeMapping, + false); + + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + } + + protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType) + { + if (selector != null) + { + source = TranslateSelect(source, selector); + } - protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector) => throw new NotImplementedException(); + var projection = (SqlExpression)((SelectExpression)source.QueryExpression) + .GetProjectionExpression(new ProjectionMember()); + + projection = new SqlFunctionExpression( + null, + "MIN", + null, + new[] + { + projection + }, + resultType, + projection.TypeMapping, + false); + + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + } protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType) => throw new NotImplementedException(); @@ -265,7 +362,51 @@ protected override ShapedQueryExpression TranslateSkip(ShapedQueryExpression sou protected override ShapedQueryExpression TranslateSkipWhile(ShapedQueryExpression source, LambdaExpression predicate) => throw new NotImplementedException(); - protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector) => throw new NotImplementedException(); + protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector, Type resultType) + { + if (selector != null) + { + source = TranslateSelect(source, selector); + } + + var serverOutputType = resultType.UnwrapNullableType(); + var projection = (SqlExpression)((SelectExpression)source.QueryExpression) + .GetProjectionExpression(new ProjectionMember()); + + if (serverOutputType == typeof(float)) + { + projection = new SqlCastExpression( + new SqlFunctionExpression( + null, + "SUM", + null, + new[] + { + projection + }, + typeof(double), + _typeMappingSource.FindMapping(typeof(double)), + false), + serverOutputType, + projection.TypeMapping); + } + else + { + projection = new SqlFunctionExpression( + null, + "SUM", + null, + new[] + { + projection + }, + serverOutputType, + projection.TypeMapping, + false); + } + + return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType); + } protected override ShapedQueryExpression TranslateTake(ShapedQueryExpression source, Expression count) { @@ -331,5 +472,50 @@ private SqlExpression TranslateLambdaExpression( return TranslateExpression((SelectExpression)shapedQueryExpression.QueryExpression, lambdaBody, condition); } + + private ShapedQueryExpression AggregateResultShaper( + ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType) + { + var selectExpression = (SelectExpression)source.QueryExpression; + selectExpression.ApplyProjection( + new Dictionary + { + { new ProjectionMember(), projection } + }); + + Expression shaper = new ProjectionBindingExpression(selectExpression, new ProjectionMember(), projection.Type); + + if (throwOnNullResult) + { + var resultVariable = Expression.Variable(projection.Type, "result"); + + shaper = Expression.Block( + new[] { resultVariable }, + Expression.Assign(resultVariable, shaper), + Expression.Condition( + Expression.Equal(resultVariable, Expression.Default(projection.Type)), + Expression.Throw( + Expression.New( + typeof(InvalidOperationException).GetConstructors() + .Single(ci => ci.GetParameters().Length == 1), + Expression.Constant(RelationalStrings.NoElements)), + projection.Type), + resultVariable), + resultType != resultVariable.Type + ? Expression.Convert(resultVariable, resultType) + : (Expression)resultVariable); + } + else if (resultType.IsNullableType()) + { + shaper = Expression.Convert(shaper, resultType); + } + + source.ShaperExpression + = Expression.Lambda( + shaper, + source.ShaperExpression.Parameters); + + return source; + } } } diff --git a/src/EFCore.Relational/Query/PipeLine/TypeMappingApplyingExpressionVisitor.cs b/src/EFCore.Relational/Query/PipeLine/TypeMappingApplyingExpressionVisitor.cs index c1f14a14dc9..65bcf3ccc09 100644 --- a/src/EFCore.Relational/Query/PipeLine/TypeMappingApplyingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/PipeLine/TypeMappingApplyingExpressionVisitor.cs @@ -71,7 +71,7 @@ protected virtual SqlExpression ApplyTypeMappingOnSqlCast( { if (typeMapping == null) { - throw new InvalidOperationException("TypeMapping should not be null."); + return sqlCastExpression; } var operand = ApplyTypeMapping( diff --git a/src/EFCore/Query/PipeLine/QueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore/Query/PipeLine/QueryableMethodTranslatingExpressionVisitor.cs index 517d9a63090..95338be818e 100644 --- a/src/EFCore/Query/PipeLine/QueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore/Query/PipeLine/QueryableMethodTranslatingExpressionVisitor.cs @@ -57,7 +57,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp shapedQueryExpression, methodCallExpression.Arguments.Count == 2 ? UnwrapLambdaFromQuoteExpression(methodCallExpression.Arguments[1]) - : null); + : null, + methodCallExpression.Type); case nameof(Queryable.Cast): return TranslateCast(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]); @@ -268,7 +269,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp shapedQueryExpression, methodCallExpression.Arguments.Count == 2 ? UnwrapLambdaFromQuoteExpression(methodCallExpression.Arguments[1]) - : null); + : null, + methodCallExpression.Type); case nameof(Queryable.Min): shapedQueryExpression.ResultType = ResultType.Single; @@ -276,7 +278,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp shapedQueryExpression, methodCallExpression.Arguments.Count == 2 ? UnwrapLambdaFromQuoteExpression(methodCallExpression.Arguments[1]) - : null); + : null, + methodCallExpression.Type); case nameof(Queryable.OfType): return TranslateOfType(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]); @@ -349,7 +352,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp shapedQueryExpression, methodCallExpression.Arguments.Count == 2 ? UnwrapLambdaFromQuoteExpression(methodCallExpression.Arguments[1]) - : null); + : null, + methodCallExpression.Type); case nameof(Queryable.Take): return TranslateTake(shapedQueryExpression, methodCallExpression.Arguments[1]); @@ -409,7 +413,7 @@ private LambdaExpression UnwrapLambdaFromQuoteExpression(Expression expression) protected abstract ShapedQueryExpression TranslateAll(ShapedQueryExpression source, LambdaExpression predicate); protected abstract ShapedQueryExpression TranslateAny(ShapedQueryExpression source, LambdaExpression predicate); - protected abstract ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector); + protected abstract ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector, Type resultType); protected abstract ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType); protected abstract ShapedQueryExpression TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2); protected abstract ShapedQueryExpression TranslateContains(ShapedQueryExpression source, Expression item); @@ -425,8 +429,8 @@ private LambdaExpression UnwrapLambdaFromQuoteExpression(Expression expression) protected abstract ShapedQueryExpression TranslateJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector); protected abstract ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpression source, LambdaExpression predicate, bool returnDefault); protected abstract ShapedQueryExpression TranslateLongCount(ShapedQueryExpression source, LambdaExpression predicate); - protected abstract ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector); - protected abstract ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector); + protected abstract ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector, Type resultType); + protected abstract ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType); protected abstract ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType); protected abstract ShapedQueryExpression TranslateOrderBy(ShapedQueryExpression source, LambdaExpression keySelector, bool ascending); protected abstract ShapedQueryExpression TranslateReverse(ShapedQueryExpression source); @@ -436,7 +440,7 @@ private LambdaExpression UnwrapLambdaFromQuoteExpression(Expression expression) protected abstract ShapedQueryExpression TranslateSingleOrDefault(ShapedQueryExpression source, LambdaExpression predicate, bool returnDefault); protected abstract ShapedQueryExpression TranslateSkip(ShapedQueryExpression source, Expression count); protected abstract ShapedQueryExpression TranslateSkipWhile(ShapedQueryExpression source, LambdaExpression predicate); - protected abstract ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector); + protected abstract ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector, Type resultType); protected abstract ShapedQueryExpression TranslateTake(ShapedQueryExpression source, Expression count); protected abstract ShapedQueryExpression TranslateTakeWhile(ShapedQueryExpression source, LambdaExpression predicate); protected abstract ShapedQueryExpression TranslateThenBy(ShapedQueryExpression source, LambdaExpression keySelector, bool ascending);