Skip to content

Commit

Permalink
Query: Fix #4588 - Query :: SelectMany-GroupJoin-DefaultIfEmpty isn't…
Browse files Browse the repository at this point in the history
… being lifted into SQL for some complex queries, resulting in extensive client-side evaluation.

Adds relational GroupJoin/DefaultIfEmpty elimination.
  • Loading branch information
anpete committed Sep 23, 2016
1 parent 847a266 commit 7278626
Show file tree
Hide file tree
Showing 17 changed files with 627 additions and 364 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ var structuralComparisonExpression
var leftExpression = Visit(expression.Left);
var rightExpression = Visit(expression.Right);

return leftExpression != null && rightExpression != null
return leftExpression != null
&& rightExpression != null
? Expression.MakeBinary(
expression.NodeType,
leftExpression,
Expand Down Expand Up @@ -252,6 +253,7 @@ protected override Expression VisitConditional(ConditionalExpression expression)
var ifTrue = Visit(expression.IfTrue);
var ifFalse = Visit(expression.IfFalse);

// ReSharper disable once ConditionIsAlwaysTrueOrFalse
if (test != null
&& ifTrue != null
&& ifFalse != null)
Expand Down Expand Up @@ -991,9 +993,21 @@ protected override Expression VisitExtension(Expression expression)
var nullConditionalExpression
= expression as NullConditionalExpression;

return nullConditionalExpression != null
? Visit(nullConditionalExpression.AccessOperation)
: base.VisitExtension(expression);
if (nullConditionalExpression != null)
{
var newAccessOperation = Visit(nullConditionalExpression.AccessOperation);

if (newAccessOperation != null
&& newAccessOperation.Type != nullConditionalExpression.Type)
{
newAccessOperation
= Expression.Convert(newAccessOperation, nullConditionalExpression.Type);
}

return newAccessOperation;
}

return base.VisitExtension(expression);
}

/// <summary>
Expand All @@ -1007,6 +1021,28 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr
{
Check.NotNull(expression, nameof(expression));

if (!_inProjection)
{
var joinClause
= expression.ReferencedQuerySource as JoinClause;

if (joinClause != null)
{
var entityType
= _queryModelVisitor.QueryCompilationContext.Model
.FindEntityType(joinClause.ItemType);

if (entityType != null)
{
return Visit(
EntityQueryModelVisitor.CreatePropertyExpression(
expression, entityType.FindPrimaryKey().Properties[0]));
}

return null;
}
}

var selector
= ((expression.ReferencedQuerySource as FromClauseBase)
?.FromExpression as SubQueryExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,14 +476,26 @@ public virtual int AddToProjection([NotNull] Expression expression)
public virtual int AddToProjection([NotNull] Expression expression, bool resetProjectStar)
{
Check.NotNull(expression, nameof(expression));

if (expression.NodeType == ExpressionType.Convert)
{
var unaryExpression = (UnaryExpression)expression;

if (unaryExpression.Type.UnwrapNullableType()
== unaryExpression.Operand.Type)
{
expression = unaryExpression.Operand;
}
}

var columnExpression = expression as ColumnExpression;
var aliasExpression = expression as AliasExpression;

if (columnExpression != null)
{
return AddToProjection(columnExpression);
}

var aliasExpression = expression as AliasExpression;

if (aliasExpression != null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,10 +680,11 @@ private static Expression HandleSkip(HandlerContext handlerContext)

private static Expression HandleSum(HandlerContext handlerContext)
{
if (!handlerContext.QueryModelVisitor.RequiresClientProjection)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
var sumExpression
= new SumExpression(handlerContext.SelectExpression.Projection.Single());
= new SumExpression(handlerContext.SelectExpression.Projection.First());

handlerContext.SelectExpression.SetProjectionExpression(sumExpression);

Expand All @@ -699,12 +700,14 @@ private static Expression HandleTake(HandlerContext handlerContext)
{
var takeResultOperator = (TakeResultOperator)handlerContext.ResultOperator;

var sqlTranslatingExpressionVisitor = handlerContext.CreateSqlTranslatingVisitor(bindParentQueries: true);
var sqlTranslatingExpressionVisitor
= handlerContext.CreateSqlTranslatingVisitor(bindParentQueries: true);

var limit = sqlTranslatingExpressionVisitor.Visit(takeResultOperator.Count);

if (limit != null)
{
handlerContext.SelectExpression.Limit = takeResultOperator.Count;
handlerContext.SelectExpression.Limit = limit;

return handlerContext.EvalOnServer;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
using Remotion.Linq;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ExpressionVisitors;
using Remotion.Linq.Clauses.ResultOperators;

namespace Microsoft.EntityFrameworkCore.Query
{
Expand Down Expand Up @@ -243,7 +245,7 @@ public virtual void RegisterSubQueryVisitor(
Check.NotNull(querySource, nameof(querySource));
Check.NotNull(queryModelVisitor, nameof(queryModelVisitor));

_subQueryModelVisitorsBySource.Add(querySource, queryModelVisitor);
_subQueryModelVisitorsBySource[querySource] = queryModelVisitor;
}

/// <summary>
Expand Down Expand Up @@ -341,12 +343,13 @@ protected override void IncludeNavigations(
Check.NotNull(includeSpecification, nameof(includeSpecification));
Check.NotNull(resultType, nameof(resultType));

var includeExpressionVisitor = _includeExpressionVisitorFactory.Create(
includeSpecification.QuerySource,
includeSpecification.NavigationPath,
QueryCompilationContext,
_navigationIndexMap[includeSpecification],
querySourceRequiresTracking);
var includeExpressionVisitor
= _includeExpressionVisitorFactory.Create(
includeSpecification.QuerySource,
includeSpecification.NavigationPath,
QueryCompilationContext,
_navigationIndexMap[includeSpecification],
querySourceRequiresTracking);

Expression = includeExpressionVisitor.Visit(Expression);
}
Expand Down Expand Up @@ -510,15 +513,12 @@ var fromQuerySourceReferenceExpression
private bool CanFlattenSelectMany()
{
var selectManyExpression = Expression as MethodCallExpression;
if (selectManyExpression == null
|| !selectManyExpression.Method.MethodIsClosedFormOf(LinqOperatorProvider.SelectMany)
|| !IsShapedQueryExpression(selectManyExpression.Arguments[0] as MethodCallExpression, innerShapedQuery: false)
|| !IsShapedQueryExpression((selectManyExpression.Arguments[1] as LambdaExpression)?.Body as MethodCallExpression, innerShapedQuery: true))
{
return false;
}

return true;
return selectManyExpression != null
&& selectManyExpression.Method.MethodIsClosedFormOf(LinqOperatorProvider.SelectMany)
&& IsShapedQueryExpression(selectManyExpression.Arguments[0] as MethodCallExpression, innerShapedQuery: false)
&& IsShapedQueryExpression((selectManyExpression.Arguments[1] as LambdaExpression)
?.Body as MethodCallExpression, innerShapedQuery: true);
}

private bool IsShapedQueryExpression(MethodCallExpression shapedQueryExpression, bool innerShapedQuery)
Expand All @@ -538,18 +538,14 @@ private bool IsShapedQueryExpression(MethodCallExpression shapedQueryExpression,
}
}

if (shapedQueryExpression == null || shapedQueryExpression.Arguments.Count != 3)
if (shapedQueryExpression.Arguments.Count != 3)
{
return false;
}

var shaper = shapedQueryExpression.Arguments[2] as ConstantExpression;
if (shaper == null || !(shaper.Value is Shaper))
{
return false;
}

return true;
return shaper?.Value is Shaper;
}

/// <summary>
Expand Down Expand Up @@ -627,7 +623,7 @@ public override void VisitGroupJoinClause(
index,
() => base.VisitGroupJoinClause(groupJoinClause, queryModel, index),
LinqOperatorProvider.GroupJoin,
outerJoin: true);
groupJoin: true);
}

/// <summary>
Expand All @@ -638,14 +634,14 @@ public override void VisitGroupJoinClause(
/// <param name="index"> Index of the node being visited. </param>
/// <param name="baseVisitAction"> The base visit action. </param>
/// <param name="operatorToFlatten"> The operator to flatten. </param>
/// <param name="outerJoin"> true if an outer join should be performed. </param>
/// <param name="groupJoin"> true if an outer join should be performed. </param>
protected virtual void OptimizeJoinClause(
[NotNull] JoinClause joinClause,
[NotNull] QueryModel queryModel,
int index,
[NotNull] Action baseVisitAction,
[NotNull] MethodInfo operatorToFlatten,
bool outerJoin = false)
bool groupJoin = false)
{
Check.NotNull(joinClause, nameof(joinClause));
Check.NotNull(queryModel, nameof(queryModel));
Expand All @@ -664,11 +660,14 @@ var previousSelectExpression
var previousSelectProjectionCount
= previousSelectExpression?.Projection.Count ?? -1;

var previousParameter = CurrentParameter;
var previousMapping = SnapshotQuerySourceMapping(queryModel);

baseVisitAction();

if (!RequiresClientSelectMany
&& previousSelectExpression != null
&& (!operatorToFlatten.MethodIsClosedFormOf(LinqOperatorProvider.GroupJoin)
&& (!groupJoin
|| CanFlattenGroupJoin()))
{
var selectExpression = TryGetQuery(joinClause);
Expand Down Expand Up @@ -700,13 +699,13 @@ var projection
: Enumerable.Empty<Expression>();

var joinExpression
= !outerJoin
= !groupJoin
? previousSelectExpression.AddInnerJoin(tableExpression, projection)
: previousSelectExpression.AddLeftOuterJoin(tableExpression, projection);

joinExpression.Predicate = predicate;

if (outerJoin)
if (groupJoin)
{
var outerJoinOrderingExtractor = new OuterJoinOrderingExtractor();

Expand All @@ -717,6 +716,61 @@ var joinExpression
previousSelectExpression
.AddToOrderBy(new Ordering(expression, OrderingDirection.Asc));
}

var additionalFromClause
= queryModel.BodyClauses.ElementAtOrDefault(index + 1)
as AdditionalFromClause;

var subQueryModel
= (additionalFromClause?.FromExpression as SubQueryExpression)?.QueryModel;

if (subQueryModel != null
&& subQueryModel.ResultOperators.Count == 1
&& subQueryModel.ResultOperators[0] is DefaultIfEmptyResultOperator)
{
var groupJoinClause
= (subQueryModel.MainFromClause.FromExpression as QuerySourceReferenceExpression)
?.ReferencedQuerySource as GroupJoinClause;

if (groupJoinClause?.JoinClause == joinClause
&& queryModel.CountQuerySourceReferences(groupJoinClause) == 1)
{
queryModel.BodyClauses.RemoveAt(index + 1);

var querySourceMapping = new QuerySourceMapping();

querySourceMapping.AddMapping(
additionalFromClause,
new QuerySourceReferenceExpression(joinClause));

queryModel.TransformExpressions(e =>
ReferenceReplacingExpressionVisitor
.ReplaceClauseReferences(
e,
querySourceMapping,
throwOnUnmappedReferences: false));

Expression = ((MethodCallExpression)Expression).Arguments[0];

CurrentParameter = previousParameter;

foreach (var mapping in previousMapping)
{
QueryCompilationContext.QuerySourceMapping
.ReplaceMapping(mapping.Key, mapping.Value);
}

var previousProjectionCount = previousSelectExpression.Projection.Count;

base.VisitJoinClause(joinClause, queryModel, index);

previousSelectExpression.RemoveRangeFromProjection(previousProjectionCount);

QueriesBySource.Remove(joinClause);

operatorToFlatten = LinqOperatorProvider.Join;
}
}
}

Expression
Expand All @@ -739,18 +793,40 @@ var joinExpression
}
}

private bool CanFlattenGroupJoin()
private Dictionary<IQuerySource, Expression> SnapshotQuerySourceMapping(QueryModel queryModel)
{
var groupJoinExpression = Expression as MethodCallExpression;
if (groupJoinExpression == null
|| !groupJoinExpression.Method.MethodIsClosedFormOf(LinqOperatorProvider.GroupJoin)
|| !IsShapedQueryExpression(groupJoinExpression.Arguments[0] as MethodCallExpression, innerShapedQuery: false)
|| !IsShapedQueryExpression(groupJoinExpression.Arguments[1] as MethodCallExpression, innerShapedQuery: true))
var previousMapping
= new Dictionary<IQuerySource, Expression>
{
{
queryModel.MainFromClause,
QueryCompilationContext.QuerySourceMapping
.GetExpression(queryModel.MainFromClause)
}
};

foreach (var querySource in queryModel.BodyClauses.OfType<IQuerySource>())
{
return false;
if (QueryCompilationContext.QuerySourceMapping.ContainsMapping(querySource))
{
previousMapping.Add(
querySource,
QueryCompilationContext.QuerySourceMapping
.GetExpression(querySource));
}
}

return true;
return previousMapping;
}

private bool CanFlattenGroupJoin()
{
var groupJoinExpression = Expression as MethodCallExpression;

return groupJoinExpression != null
&& groupJoinExpression.Method.MethodIsClosedFormOf(LinqOperatorProvider.GroupJoin)
&& IsShapedQueryExpression(groupJoinExpression.Arguments[0] as MethodCallExpression, innerShapedQuery: false)
&& IsShapedQueryExpression(groupJoinExpression.Arguments[1] as MethodCallExpression, innerShapedQuery: true);
}

private class OuterJoinOrderingExtractor : ExpressionVisitor
Expand Down
Loading

0 comments on commit 7278626

Please sign in to comment.