Skip to content

Commit

Permalink
Query: Move processing of aggregate functions inside SqlTranslator
Browse files Browse the repository at this point in the history
This allows us to use those methods to translate aggregates after GroupBy

Resolves #15718
  • Loading branch information
smitpatel committed Jul 2, 2019
1 parent 37415eb commit b62c2a8
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,34 +151,12 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

if (selector != null)
{
source = TranslateSelect(source, selector);
}
var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
? selectExpression.GetMappedProjection(new ProjectionMember())
: ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

var inputType = projection.Type.UnwrapNullableType();
if (inputType == typeof(int)
|| inputType == typeof(long))
{
projection = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(projection, typeof(double)));
}

if (inputType == typeof(float))
{
projection = _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"AVG", new[] { projection }, typeof(double), null),
projection.Type,
projection.TypeMapping);
}
else
{
projection = _sqlExpressionFactory.Function(
"AVG", new[] { projection }, projection.Type, projection.TypeMapping);
}
var projection = _sqlTranslator.TranslateAverage(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
Expand Down Expand Up @@ -237,8 +215,7 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so
source = TranslateWhere(source, predicate);
}

var translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(int)));
var translation = _sqlTranslator.TranslateCount();

var projectionMapping = new Dictionary<ProjectionMember, Expression>
{
Expand Down Expand Up @@ -480,8 +457,7 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio
source = TranslateWhere(source, predicate);
}

var translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(long)));
var translation = _sqlTranslator.TranslateLongCount();
var projectionMapping = new Dictionary<ProjectionMember, Expression>
{
{ new ProjectionMember(), translation }
Expand All @@ -499,14 +475,12 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
? selectExpression.GetMappedProjection(new ProjectionMember())
: ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);
var projection = _sqlTranslator.TranslateMax(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
Expand All @@ -516,14 +490,12 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
? selectExpression.GetMappedProjection(new ProjectionMember())
: ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);
var projection = _sqlTranslator.TranslateMin(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
Expand Down Expand Up @@ -737,27 +709,12 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();
var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
? selectExpression.GetMappedProjection(new ProjectionMember())
: ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var serverOutputType = resultType.UnwrapNullableType();
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

if (serverOutputType == typeof(float))
{
projection = _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function("SUM", new[] { projection }, typeof(double)),
serverOutputType,
projection.TypeMapping);
}
else
{
projection = _sqlExpressionFactory.Function(
"SUM", new[] { projection }, serverOutputType, projection.TypeMapping);
}
var projection = _sqlTranslator.TranslateSum(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public RelationalSqlTranslatingExpressionVisitor(
_sqlVerifyingExpressionVisitor = new SqlTypeMappingVerifyingExpressionVisitor();
}

public SqlExpression Translate(Expression expression)
public virtual SqlExpression Translate(Expression expression)
{
var result = Visit(expression);

Expand All @@ -64,6 +64,83 @@ public SqlExpression Translate(Expression expression)
return null;
}

public virtual SqlExpression TranslateAverage(Expression expression)
{
if (!(expression is SqlExpression sqlExpression))
{
sqlExpression = Translate(expression);
}

var inputType = sqlExpression.Type.UnwrapNullableType();
if (inputType == typeof(int)
|| inputType == typeof(long))
{
sqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(sqlExpression, typeof(double)));
}

return inputType == typeof(float)
? _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"AVG", new[] { sqlExpression }, typeof(double), null),
sqlExpression.Type,
sqlExpression.TypeMapping)
: (SqlExpression)_sqlExpressionFactory.Function(
"AVG", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping);
}

public virtual SqlExpression TranslateCount(Expression expression = null)
{
// TODO: Translate Count with predicate for GroupBy
return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(int)));
}

public virtual SqlExpression TranslateLongCount(Expression expression = null)
{
// TODO: Translate Count with predicate for GroupBy
return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(long)));
}

public virtual SqlExpression TranslateMax(Expression expression)
{
if (!(expression is SqlExpression sqlExpression))
{
sqlExpression = Translate(expression);
}

return _sqlExpressionFactory.Function("MAX", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping);
}

public virtual SqlExpression TranslateMin(Expression expression)
{
if (!(expression is SqlExpression sqlExpression))
{
sqlExpression = Translate(expression);
}

return _sqlExpressionFactory.Function("MIN", new[] { sqlExpression }, sqlExpression.Type, sqlExpression.TypeMapping);
}

public virtual SqlExpression TranslateSum(Expression expression)
{
if (!(expression is SqlExpression sqlExpression))
{
sqlExpression = Translate(expression);
}

var inputType = sqlExpression.Type.UnwrapNullableType();

return inputType == typeof(float)
? _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function("SUM", new[] { sqlExpression }, typeof(double)),
inputType,
sqlExpression.TypeMapping)
: (SqlExpression)_sqlExpressionFactory.Function(
"SUM", new[] { sqlExpression }, inputType, sqlExpression.TypeMapping);
}

private class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitExtension(Expression node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ private static readonly HashSet<ExpressionType> _arithmeticOperatorTypes
ExpressionType.Divide,
ExpressionType.Modulo,
};
// TODO: Possibly make this protected in base
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public SqlServerSqlTranslatingExpressionVisitor(
IModel model,
Expand All @@ -40,6 +42,7 @@ public SqlServerSqlTranslatingExpressionVisitor(
IMethodCallTranslatorProvider methodCallTranslatorProvider)
: base(model, queryableMethodTranslatingExpressionVisitor, sqlExpressionFactory, memberTranslatorProvider, methodCallTranslatorProvider)
{
_sqlExpressionFactory = sqlExpressionFactory;
}

protected override Expression VisitBinary(BinaryExpression binaryExpression)
Expand All @@ -59,6 +62,13 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
: visitedExpression;
}

public override SqlExpression TranslateLongCount(Expression expression = null)
{
// TODO: Translate Count with predicate for GroupBy
return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT_BIG", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(long)));
}

private static string GetProviderType(SqlExpression expression)
{
return expression.TypeMapping?.StoreType;
Expand Down
13 changes: 13 additions & 0 deletions test/EFCore.Cosmos.FunctionalTests/Query/OwnedQueryCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public OwnedQueryCosmosTest(OwnedQueryCosmosFixture fixture, ITestOutputHelper t
//TestLoggerFactory.TestOutputHelper = testOutputHelper;
}

[ConditionalFact(Skip = "#16392")]
public override void Navigation_rewrite_on_owned_collection()
{
base.Navigation_rewrite_on_owned_collection();
Expand All @@ -30,6 +31,18 @@ FROM root c
WHERE ((c[""Discriminator""] = ""LeafB"") OR ((c[""Discriminator""] = ""LeafA"") OR ((c[""Discriminator""] = ""Branch"") OR (c[""Discriminator""] = ""OwnedPerson""))))");
}

[ConditionalFact(Skip = "#16392")]
public override void Navigation_rewrite_on_owned_collection_with_composition()
{
base.Navigation_rewrite_on_owned_collection_with_composition();
}

[ConditionalFact(Skip = "#16392")]
public override void Navigation_rewrite_on_owned_collection_with_composition_complex()
{
base.Navigation_rewrite_on_owned_collection_with_composition_complex();
}

[ConditionalFact(Skip = "Owned collection #12086")]
public override void Navigation_rewrite_on_owned_reference_projecting_entity()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1864,5 +1864,11 @@ SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND @__p_0)");
}

[ConditionalTheory(Skip = "Issue#16391")]
public override Task Where_is_conditional(bool isAsync)
{
return base.Where_is_conditional(isAsync);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public virtual Task Union_nested(bool isAsync)
.Union(cs.Where(e => e.City == "London")),
entryCount: 25);

[ConditionalTheory]
[ConditionalTheory(Skip = "Issue#16365")]
[MemberData(nameof(IsAsyncData))]
public virtual void Union_non_entity(bool isAsync)
=> AssertQuery<Customer>(isAsync, cs => cs
Expand Down
14 changes: 7 additions & 7 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2749,7 +2749,7 @@ from o in os
select c.CustomerID);
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task SelectMany_LongCount(bool isAsync)
{
Expand Down Expand Up @@ -5197,7 +5197,7 @@ public virtual Task Select_orderBy_take_count(bool isAsync)
cs => cs.OrderBy(c => c.Country).Take(7));
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_take_long_count(bool isAsync)
{
Expand All @@ -5206,7 +5206,7 @@ public virtual Task Select_take_long_count(bool isAsync)
cs => cs.Take(7));
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_orderBy_take_long_count(bool isAsync)
{
Expand Down Expand Up @@ -5269,7 +5269,7 @@ public virtual Task Select_orderBy_skip_count(bool isAsync)
cs => cs.OrderBy(c => c.Country).Skip(7));
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_skip_long_count(bool isAsync)
{
Expand All @@ -5278,7 +5278,7 @@ public virtual Task Select_skip_long_count(bool isAsync)
cs => cs.Skip(7));
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_orderBy_skip_long_count(bool isAsync)
{
Expand Down Expand Up @@ -5332,7 +5332,7 @@ public virtual Task Select_distinct_count(bool isAsync)
cs => cs.Distinct());
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_distinct_long_count(bool isAsync)
{
Expand Down Expand Up @@ -5428,7 +5428,7 @@ orderby c.CustomerID
select c.CustomerID);
}

[ConditionalTheory(Skip = "Issue#16365")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Comparing_navigations_using_Equals(bool isAsync)
{
Expand Down
Loading

0 comments on commit b62c2a8

Please sign in to comment.