Skip to content

Commit

Permalink
Translate AVG/SUM/MIN/MAX
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel committed Feb 22, 2019
1 parent c2ec462 commit 65dbb6e
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour
return source;
}

protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector)
protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
=> TranslateScalarAggregate(source, selector, nameof(Enumerable.Average));

protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType) => throw new NotImplementedException();
Expand Down Expand Up @@ -192,10 +192,10 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio
return source;
}

protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector)
protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
=> TranslateScalarAggregate(source, selector, nameof(Enumerable.Max));

protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector)
protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
=> TranslateScalarAggregate(source, selector, nameof(Enumerable.Min));

protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType) => throw new NotImplementedException();
Expand Down Expand Up @@ -269,7 +269,7 @@ protected override ShapedQueryExpression TranslateSkip(ShapedQueryExpression sou

protected override ShapedQueryExpression TranslateSkipWhile(ShapedQueryExpression source, LambdaExpression predicate) => throw new NotImplementedException();

protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector)
protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
=> TranslateScalarAggregate(source, selector, nameof(Enumerable.Sum));

protected override ShapedQueryExpression TranslateTake(ShapedQueryExpression source, Expression count)
Expand Down
19 changes: 14 additions & 5 deletions src/EFCore.Relational/Query/PipeLine/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,20 @@ protected override Expression VisitTable(TableExpression tableExpression)

protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpression)
{
Visit(sqlBinaryExpression.Left);

_relationalCommandBuilder.Append(_operatorMap[sqlBinaryExpression.OperatorType]);

Visit(sqlBinaryExpression.Right);
if (sqlBinaryExpression.OperatorType == ExpressionType.Coalesce)
{
_relationalCommandBuilder.Append("COALESCE(");
Visit(sqlBinaryExpression.Left);
_relationalCommandBuilder.Append(", ");
Visit(sqlBinaryExpression.Right);
_relationalCommandBuilder.Append(")");
}
else
{
Visit(sqlBinaryExpression.Left);
_relationalCommandBuilder.Append(_operatorMap[sqlBinaryExpression.OperatorType]);
Visit(sqlBinaryExpression.Right);
}

return sqlBinaryExpression;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.PipeLine;
using Microsoft.EntityFrameworkCore.Relational.Query.PipeLine.SqlExpressions;
using Microsoft.EntityFrameworkCore.Storage;
Expand Down Expand Up @@ -79,7 +80,57 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour
return source;
}

protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector) => throw new NotImplementedException();
protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
{
if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)((SelectExpression)source.QueryExpression)
.GetProjectionExpression(new ProjectionMember());

var inputType = projection.Type.UnwrapNullableType();
if (inputType == typeof(int)
|| inputType == typeof(long))
{
projection = new SqlCastExpression(projection, typeof(double), _typeMappingSource.FindMapping(typeof(double)));
}

if (projection.Type.UnwrapNullableType() == typeof(float))
{
projection = new SqlCastExpression(
new SqlFunctionExpression(
null,
"AVG",
null,
new[]
{
projection
},
typeof(double),
_typeMappingSource.FindMapping(typeof(double)),
false),
projection.Type,
projection.TypeMapping);
}
else
{
projection = new SqlFunctionExpression(
null,
"AVG",
null,
new[]
{
projection
},
projection.Type,
projection.TypeMapping,
false);
}

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}

protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType) => throw new NotImplementedException();

Expand Down Expand Up @@ -183,9 +234,55 @@ protected override ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpre

protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpression source, LambdaExpression predicate) => throw new NotImplementedException();

protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector) => throw new NotImplementedException();
protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
{
if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)((SelectExpression)source.QueryExpression)
.GetProjectionExpression(new ProjectionMember());

projection = new SqlFunctionExpression(
null,
"MAX",
null,
new[]
{
projection
},
resultType,
projection.TypeMapping,
false);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}

protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
{
if (selector != null)
{
source = TranslateSelect(source, selector);
}

protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector) => throw new NotImplementedException();
var projection = (SqlExpression)((SelectExpression)source.QueryExpression)
.GetProjectionExpression(new ProjectionMember());

projection = new SqlFunctionExpression(
null,
"MIN",
null,
new[]
{
projection
},
resultType,
projection.TypeMapping,
false);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}

protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType) => throw new NotImplementedException();

Expand Down Expand Up @@ -265,7 +362,51 @@ protected override ShapedQueryExpression TranslateSkip(ShapedQueryExpression sou

protected override ShapedQueryExpression TranslateSkipWhile(ShapedQueryExpression source, LambdaExpression predicate) => throw new NotImplementedException();

protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector) => throw new NotImplementedException();
protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
{
if (selector != null)
{
source = TranslateSelect(source, selector);
}

var serverOutputType = resultType.UnwrapNullableType();
var projection = (SqlExpression)((SelectExpression)source.QueryExpression)
.GetProjectionExpression(new ProjectionMember());

if (serverOutputType == typeof(float))
{
projection = new SqlCastExpression(
new SqlFunctionExpression(
null,
"SUM",
null,
new[]
{
projection
},
typeof(double),
_typeMappingSource.FindMapping(typeof(double)),
false),
serverOutputType,
projection.TypeMapping);
}
else
{
projection = new SqlFunctionExpression(
null,
"SUM",
null,
new[]
{
projection
},
serverOutputType,
projection.TypeMapping,
false);
}

return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType);
}

protected override ShapedQueryExpression TranslateTake(ShapedQueryExpression source, Expression count)
{
Expand Down Expand Up @@ -331,5 +472,50 @@ private SqlExpression TranslateLambdaExpression(

return TranslateExpression((SelectExpression)shapedQueryExpression.QueryExpression, lambdaBody, condition);
}

private ShapedQueryExpression AggregateResultShaper(
ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.ApplyProjection(
new Dictionary<ProjectionMember, Expression>
{
{ new ProjectionMember(), projection }
});

Expression shaper = new ProjectionBindingExpression(selectExpression, new ProjectionMember(), projection.Type);

if (throwOnNullResult)
{
var resultVariable = Expression.Variable(projection.Type, "result");

shaper = Expression.Block(
new[] { resultVariable },
Expression.Assign(resultVariable, shaper),
Expression.Condition(
Expression.Equal(resultVariable, Expression.Default(projection.Type)),
Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(RelationalStrings.NoElements)),
projection.Type),
resultVariable),
resultType != resultVariable.Type
? Expression.Convert(resultVariable, resultType)
: (Expression)resultVariable);
}
else if (resultType.IsNullableType())
{
shaper = Expression.Convert(shaper, resultType);
}

source.ShaperExpression
= Expression.Lambda(
shaper,
source.ShaperExpression.Parameters);

return source;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ protected virtual SqlExpression ApplyTypeMappingOnSqlCast(
{
if (typeMapping == null)
{
throw new InvalidOperationException("TypeMapping should not be null.");
return sqlCastExpression;
}

var operand = ApplyTypeMapping(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
shapedQueryExpression,
methodCallExpression.Arguments.Count == 2
? UnwrapLambdaFromQuoteExpression(methodCallExpression.Arguments[1])
: null);
: null,
methodCallExpression.Type);

case nameof(Queryable.Cast):
return TranslateCast(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]);
Expand Down Expand Up @@ -268,15 +269,17 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
shapedQueryExpression,
methodCallExpression.Arguments.Count == 2
? UnwrapLambdaFromQuoteExpression(methodCallExpression.Arguments[1])
: null);
: null,
methodCallExpression.Type);

case nameof(Queryable.Min):
shapedQueryExpression.ResultType = ResultType.Single;
return TranslateMin(
shapedQueryExpression,
methodCallExpression.Arguments.Count == 2
? UnwrapLambdaFromQuoteExpression(methodCallExpression.Arguments[1])
: null);
: null,
methodCallExpression.Type);

case nameof(Queryable.OfType):
return TranslateOfType(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]);
Expand Down Expand Up @@ -349,7 +352,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
shapedQueryExpression,
methodCallExpression.Arguments.Count == 2
? UnwrapLambdaFromQuoteExpression(methodCallExpression.Arguments[1])
: null);
: null,
methodCallExpression.Type);

case nameof(Queryable.Take):
return TranslateTake(shapedQueryExpression, methodCallExpression.Arguments[1]);
Expand Down Expand Up @@ -409,7 +413,7 @@ private LambdaExpression UnwrapLambdaFromQuoteExpression(Expression expression)

protected abstract ShapedQueryExpression TranslateAll(ShapedQueryExpression source, LambdaExpression predicate);
protected abstract ShapedQueryExpression TranslateAny(ShapedQueryExpression source, LambdaExpression predicate);
protected abstract ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector);
protected abstract ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector, Type resultType);
protected abstract ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType);
protected abstract ShapedQueryExpression TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2);
protected abstract ShapedQueryExpression TranslateContains(ShapedQueryExpression source, Expression item);
Expand All @@ -425,8 +429,8 @@ private LambdaExpression UnwrapLambdaFromQuoteExpression(Expression expression)
protected abstract ShapedQueryExpression TranslateJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector);
protected abstract ShapedQueryExpression TranslateLastOrDefault(ShapedQueryExpression source, LambdaExpression predicate, bool returnDefault);
protected abstract ShapedQueryExpression TranslateLongCount(ShapedQueryExpression source, LambdaExpression predicate);
protected abstract ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector);
protected abstract ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector);
protected abstract ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector, Type resultType);
protected abstract ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType);
protected abstract ShapedQueryExpression TranslateOfType(ShapedQueryExpression source, Type resultType);
protected abstract ShapedQueryExpression TranslateOrderBy(ShapedQueryExpression source, LambdaExpression keySelector, bool ascending);
protected abstract ShapedQueryExpression TranslateReverse(ShapedQueryExpression source);
Expand All @@ -436,7 +440,7 @@ private LambdaExpression UnwrapLambdaFromQuoteExpression(Expression expression)
protected abstract ShapedQueryExpression TranslateSingleOrDefault(ShapedQueryExpression source, LambdaExpression predicate, bool returnDefault);
protected abstract ShapedQueryExpression TranslateSkip(ShapedQueryExpression source, Expression count);
protected abstract ShapedQueryExpression TranslateSkipWhile(ShapedQueryExpression source, LambdaExpression predicate);
protected abstract ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector);
protected abstract ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector, Type resultType);
protected abstract ShapedQueryExpression TranslateTake(ShapedQueryExpression source, Expression count);
protected abstract ShapedQueryExpression TranslateTakeWhile(ShapedQueryExpression source, LambdaExpression predicate);
protected abstract ShapedQueryExpression TranslateThenBy(ShapedQueryExpression source, LambdaExpression keySelector, bool ascending);
Expand Down

0 comments on commit 65dbb6e

Please sign in to comment.