Skip to content

Commit

Permalink
Translate LongCount() on SqlServer
Browse files Browse the repository at this point in the history
Includes refactoring that makes it easier for providers to override
behavior.

Fixes dotnet#15718
  • Loading branch information
roji committed May 22, 2019
1 parent 8de54ec commit f12c064
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@ namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline
{
public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
{
private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly RelationalProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
protected RelationalSqlTranslatingExpressionVisitor SqlTranslator { get; }
protected RelationalProjectionBindingExpressionVisitor ProjectionBindingExpressionVisitor { get; }
protected ISqlExpressionFactory SqlExpressionFactory { get; }

public RelationalQueryableMethodTranslatingExpressionVisitor(
IModel model,
IQueryableMethodTranslatingExpressionVisitorFactory queryableMethodTranslatingExpressionVisitorFactory,
IRelationalSqlTranslatingExpressionVisitorFactory relationalSqlTranslatingExpressionVisitorFactory,
ISqlExpressionFactory sqlExpressionFactory)
{
_sqlTranslator = relationalSqlTranslatingExpressionVisitorFactory
SqlTranslator = relationalSqlTranslatingExpressionVisitorFactory
.Create(model, queryableMethodTranslatingExpressionVisitorFactory);

_projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(_sqlTranslator);
_sqlExpressionFactory = sqlExpressionFactory;
ProjectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(SqlTranslator);
SqlExpressionFactory = sqlExpressionFactory;
}

protected override ShapedQueryExpression TranslateAll(ShapedQueryExpression source, LambdaExpression predicate)
Expand All @@ -46,15 +46,15 @@ protected override ShapedQueryExpression TranslateAll(ShapedQueryExpression sour

if (translation != null)
{
selectExpression.ApplyPredicate(_sqlExpressionFactory.Not(translation));
selectExpression.ApplyPredicate(SqlExpressionFactory.Not(translation));
selectExpression.ReplaceProjection(new Dictionary<ProjectionMember, Expression>());
if (selectExpression.Limit == null
&& selectExpression.Offset == null)
{
selectExpression.ClearOrdering();
}

translation = _sqlExpressionFactory.Exists(selectExpression, true);
translation = SqlExpressionFactory.Exists(selectExpression, true);
source.QueryExpression = selectExpression.SetProjectionAsResult(translation);
source.ShaperExpression = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), typeof(bool));

Expand All @@ -79,7 +79,7 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour
selectExpression.ClearOrdering();
}

var translation = _sqlExpressionFactory.Exists(selectExpression, false);
var translation = SqlExpressionFactory.Exists(selectExpression, false);
source.QueryExpression = selectExpression.SetProjectionAsResult(translation);
source.ShaperExpression = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), typeof(bool));

Expand All @@ -106,21 +106,21 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression
if (inputType == typeof(int)
|| inputType == typeof(long))
{
projection = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(projection, typeof(double)));
projection = SqlExpressionFactory.ApplyDefaultTypeMapping(
SqlExpressionFactory.Convert(projection, typeof(double)));
}

if (inputType == typeof(float))
{
projection = _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
projection = SqlExpressionFactory.Convert(
SqlExpressionFactory.Function(
"AVG", new[] { projection }, typeof(double), null),
projection.Type,
projection.TypeMapping);
}
else
{
projection = _sqlExpressionFactory.Function(
projection = SqlExpressionFactory.Function(
"AVG", new[] { projection }, projection.Type, projection.TypeMapping);
}

Expand Down Expand Up @@ -155,7 +155,7 @@ protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression
}

selectExpression.ApplyProjection();
translation = _sqlExpressionFactory.In(translation, selectExpression, false);
translation = SqlExpressionFactory.In(translation, selectExpression, false);
source.QueryExpression = selectExpression.SetProjectionAsResult(translation);
source.ShaperExpression = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), typeof(bool));

Expand All @@ -181,12 +181,9 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so
source = TranslateWhere(source, predicate);
}

var translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(int)));

var _projectionMapping = new Dictionary<ProjectionMember, Expression>
{
{ new ProjectionMember(), translation }
{ new ProjectionMember(), GenerateCountExpression() }
};

selectExpression.ClearOrdering();
Expand All @@ -196,6 +193,10 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so
return source;
}

protected virtual SqlExpression GenerateCountExpression()
=> SqlExpressionFactory.ApplyDefaultTypeMapping(
SqlExpressionFactory.Function("COUNT", new[] { SqlExpressionFactory.Fragment("*") }, typeof(long)));

protected override ShapedQueryExpression TranslateDefaultIfEmpty(ShapedQueryExpression source, Expression defaultValue) => throw new NotImplementedException();

protected override ShapedQueryExpression TranslateDistinct(ShapedQueryExpression source)
Expand Down Expand Up @@ -373,7 +374,7 @@ private SqlBinaryExpression CreateJoinPredicate(
{
result = result == null
? CreateJoinPredicate(outerNew.Arguments[i], innerNew.Arguments[i])
: _sqlExpressionFactory.AndAlso(
: SqlExpressionFactory.AndAlso(
result,
CreateJoinPredicate(outerNew.Arguments[i], innerNew.Arguments[i]));
}
Expand All @@ -393,7 +394,7 @@ private SqlBinaryExpression CreateJoinPredicate(

if (left != null && right != null)
{
return _sqlExpressionFactory.Equal(left, right);
return SqlExpressionFactory.Equal(left, right);
}

return null;
Expand Down Expand Up @@ -442,11 +443,9 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio
source = TranslateWhere(source, predicate);
}

var translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(long)));
var _projectionMapping = new Dictionary<ProjectionMember, Expression>
{
{ new ProjectionMember(), translation }
{ new ProjectionMember(), GenerateLongCountExpression() }
};

selectExpression.ClearOrdering();
Expand All @@ -456,6 +455,10 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio
return source;
}

protected virtual SqlExpression GenerateLongCountExpression()
=> SqlExpressionFactory.ApplyDefaultTypeMapping(
SqlExpressionFactory.Function("COUNT", new[] { SqlExpressionFactory.Fragment("*") }, typeof(long)));

protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector, Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
Expand All @@ -472,7 +475,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour

var projection = (SqlExpression)selectExpression.GetProjectionExpression(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);
projection = SqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
Expand All @@ -493,7 +496,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour

var projection = (SqlExpression)selectExpression.GetProjectionExpression(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);
projection = SqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
Expand Down Expand Up @@ -531,7 +534,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s

var newSelectorBody = ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

source.ShaperExpression = _projectionBindingExpressionVisitor
source.ShaperExpression = ProjectionBindingExpressionVisitor
.Translate((SelectExpression)source.QueryExpression, newSelectorBody);

return source;
Expand Down Expand Up @@ -685,14 +688,14 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour

if (serverOutputType == typeof(float))
{
projection = _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function("SUM", new[] { projection }, typeof(double)),
projection = SqlExpressionFactory.Convert(
SqlExpressionFactory.Function("SUM", new[] { projection }, typeof(double)),
serverOutputType,
projection.TypeMapping);
}
else
{
projection = _sqlExpressionFactory.Function(
projection = SqlExpressionFactory.Function(
"SUM", new[] { projection }, serverOutputType, projection.TypeMapping);
}

Expand Down Expand Up @@ -754,7 +757,7 @@ protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression so

private SqlExpression TranslateExpression(Expression expression)
{
return _sqlTranslator.Translate(expression);
return SqlTranslator.Translate(expression);
}

private SqlExpression TranslateLambdaExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public static IServiceCollection AddEntityFrameworkSqlServer([NotNull] this ISer
.TryAdd<IMethodCallTranslatorProvider, SqlServerMethodCallTranslatorProvider>()
.TryAdd<IMemberTranslatorProvider, SqlServerMemberTranslatorProvider>()
.TryAdd<IShapedQueryOptimizerFactory, SqlServerShapedQueryOptimizerFactory>()
.TryAdd<IQueryableMethodTranslatingExpressionVisitorFactory, SqlServerQueryableMethodTranslatingExpressionVisitorFactory>()
.TryAddProviderSpecificServices(
b => b
.TryAddSingleton<ISqlServerValueGeneratorCache, SqlServerValueGeneratorCache>()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Pipeline;
using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline;
using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Pipeline
{
public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQueryableMethodTranslatingExpressionVisitor
{
public SqlServerQueryableMethodTranslatingExpressionVisitor(
IModel model,
IQueryableMethodTranslatingExpressionVisitorFactory queryableMethodTranslatingExpressionVisitorFactory,
IRelationalSqlTranslatingExpressionVisitorFactory relationalSqlTranslatingExpressionVisitorFactory,
ISqlExpressionFactory sqlExpressionFactory)
: base(
model,
queryableMethodTranslatingExpressionVisitorFactory,
relationalSqlTranslatingExpressionVisitorFactory,
sqlExpressionFactory)
{
}

protected override SqlExpression GenerateLongCountExpression()
=> SqlExpressionFactory.ApplyDefaultTypeMapping(
SqlExpressionFactory.Function("COUNT_BIG", new[] { SqlExpressionFactory.Fragment("*") }, typeof(long)));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Pipeline;
using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Pipeline
{
public class SqlServerQueryableMethodTranslatingExpressionVisitorFactory : IQueryableMethodTranslatingExpressionVisitorFactory
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IRelationalSqlTranslatingExpressionVisitorFactory _relationalSqlTranslatingExpressionVisitorFactory;

public SqlServerQueryableMethodTranslatingExpressionVisitorFactory(
IRelationalSqlTranslatingExpressionVisitorFactory relationalSqlTranslatingExpressionVisitorFactory,
ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
_relationalSqlTranslatingExpressionVisitorFactory = relationalSqlTranslatingExpressionVisitorFactory;
}

public QueryableMethodTranslatingExpressionVisitor Create(IModel model)
{
return new SqlServerQueryableMethodTranslatingExpressionVisitor(
model,
this,
_relationalSqlTranslatingExpressionVisitorFactory,
_sqlExpressionFactory);
}
}
}
12 changes: 6 additions & 6 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2657,7 +2657,7 @@ from o in os
select c.CustomerID);
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task SelectMany_LongCount(bool isAsync)
{
Expand Down Expand Up @@ -5103,7 +5103,7 @@ public virtual Task Select_orderBy_take_count(bool isAsync)
cs => cs.OrderBy(c => c.Country).Take(7));
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_take_long_count(bool isAsync)
{
Expand All @@ -5112,7 +5112,7 @@ public virtual Task Select_take_long_count(bool isAsync)
cs => cs.Take(7));
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_orderBy_take_long_count(bool isAsync)
{
Expand Down Expand Up @@ -5175,7 +5175,7 @@ public virtual Task Select_orderBy_skip_count(bool isAsync)
cs => cs.OrderBy(c => c.Country).Skip(7));
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_skip_long_count(bool isAsync)
{
Expand All @@ -5184,7 +5184,7 @@ public virtual Task Select_skip_long_count(bool isAsync)
cs => cs.Skip(7));
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_orderBy_skip_long_count(bool isAsync)
{
Expand Down Expand Up @@ -5238,7 +5238,7 @@ public virtual Task Select_distinct_count(bool isAsync)
cs => cs.Distinct());
}

[ConditionalTheory(Skip = "Issue#15718")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_distinct_long_count(bool isAsync)
{
Expand Down

0 comments on commit f12c064

Please sign in to comment.