Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Convert lateral joins into predicate joins only when appropriate #17182

Merged
merged 1 commit into from
Aug 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -722,19 +722,11 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
protected override ShapedQueryExpression TranslateSelectMany(
ShapedQueryExpression source, LambdaExpression collectionSelector, LambdaExpression resultSelector)
{
var defaultIfEmpty = false;
if (collectionSelector.Body is MethodCallExpression collectionEndingMethod
&& collectionEndingMethod.Method.IsGenericMethod
&& collectionEndingMethod.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument)
{
defaultIfEmpty = true;
collectionSelector = Expression.Lambda(collectionEndingMethod.Arguments[0], collectionSelector.Parameters);
}

var correlated = new CorrelationFindingExpressionVisitor().IsCorrelated(collectionSelector);
var (newCollectionSelector, correlated, defaultIfEmpty)
= new CorrelationFindingExpressionVisitor().IsCorrelated(collectionSelector);
if (correlated)
{
var collectionSelectorBody = RemapLambdaBody(source, collectionSelector);
var collectionSelectorBody = RemapLambdaBody(source, newCollectionSelector);
if (Visit(collectionSelectorBody) is ShapedQueryExpression inner)
{
var transparentIdentifierType = TransparentIdentifierFactory.Create(
Expand Down Expand Up @@ -763,7 +755,7 @@ protected override ShapedQueryExpression TranslateSelectMany(
}
else
{
if (Visit(collectionSelector.Body) is ShapedQueryExpression inner)
if (Visit(newCollectionSelector.Body) is ShapedQueryExpression inner)
{
if (defaultIfEmpty)
{
Expand Down Expand Up @@ -791,28 +783,43 @@ protected override ShapedQueryExpression TranslateSelectMany(
private class CorrelationFindingExpressionVisitor : ExpressionVisitor
{
private ParameterExpression _outerParameter;
private bool _isCorrelated;
private bool _correlated;
private bool _defaultIfEmpty;

public bool IsCorrelated(LambdaExpression lambdaExpression)
public (LambdaExpression, bool, bool) IsCorrelated(LambdaExpression lambdaExpression)
{
Debug.Assert(lambdaExpression.Parameters.Count == 1, "Multiparameter lambda passed to CorrelationFindingExpressionVisitor");
_isCorrelated = false;

_correlated = false;
_defaultIfEmpty = false;
_outerParameter = lambdaExpression.Parameters[0];

Visit(lambdaExpression.Body);
var result = Visit(lambdaExpression.Body);

return _isCorrelated;
return (Expression.Lambda(result, _outerParameter), _correlated, _defaultIfEmpty);
}

protected override Expression VisitParameter(ParameterExpression parameterExpression)
{
if (parameterExpression == _outerParameter)
{
_isCorrelated = true;
_correlated = true;
}

return base.VisitParameter(parameterExpression);
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument)
{
_defaultIfEmpty = true;
return Visit(methodCallExpression.Arguments[0]);
}

return base.VisitMethodCall(methodCallExpression);
}
}

protected override ShapedQueryExpression TranslateSelectMany(ShapedQueryExpression source, LambdaExpression selector)
Expand Down
139 changes: 96 additions & 43 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -758,53 +758,58 @@ public Expression ApplyCollectionJoin(
}

var joinPredicate = TryExtractJoinKey(innerSelectExpression);
if (joinPredicate != null)
{
if (innerSelectExpression.Offset != null
|| innerSelectExpression.Limit != null
|| innerSelectExpression.IsDistinct
|| innerSelectExpression.Predicate != null
|| innerSelectExpression.Tables.Count > 1
|| innerSelectExpression.GroupBy.Count > 1)
{
var sqlRemappingVisitor = new SqlRemappingVisitor(innerSelectExpression.PushdownIntoSubquery(),
(SelectExpression)innerSelectExpression.Tables[0]);
joinPredicate = sqlRemappingVisitor.Remap(joinPredicate);
}

var leftJoinExpression = new LeftJoinExpression(innerSelectExpression.Tables.Single(), joinPredicate);
_tables.Add(leftJoinExpression);
var containsOuterReference = new SelectExpressionCorrelationFindingExpressionVisitor(Tables)
.ContainsOuterReference(innerSelectExpression);
if (containsOuterReference && joinPredicate != null)
{
innerSelectExpression.ApplyPredicate(joinPredicate);
joinPredicate = null;
}

foreach (var ordering in innerSelectExpression.Orderings)
{
AppendOrdering(ordering.Update(MakeNullable(ordering.Expression)));
}
if (innerSelectExpression.Offset != null
|| innerSelectExpression.Limit != null
|| innerSelectExpression.IsDistinct
|| innerSelectExpression.Predicate != null
|| innerSelectExpression.Tables.Count > 1
|| innerSelectExpression.GroupBy.Count > 1)
{
var sqlRemappingVisitor = new SqlRemappingVisitor(innerSelectExpression.PushdownIntoSubquery(),
(SelectExpression)innerSelectExpression.Tables[0]);
joinPredicate = sqlRemappingVisitor.Remap(joinPredicate);
}

var indexOffset = _projection.Count;
foreach (var projection in innerSelectExpression.Projection)
{
AddToProjection(MakeNullable(projection.Expression));
}
var joinExpression = joinPredicate == null
? (TableExpressionBase)new LeftJoinLateralExpression(innerSelectExpression.Tables.Single())
: new LeftJoinExpression(innerSelectExpression.Tables.Single(), joinPredicate);
_tables.Add(joinExpression);

foreach (var identifier in innerSelectExpression._identifier.Concat(innerSelectExpression._childIdentifiers))
{
var updatedColumn = MakeNullable(identifier);
_childIdentifiers.Add(updatedColumn);
AppendOrdering(new OrderingExpression(updatedColumn, ascending: true));
}
foreach (var ordering in innerSelectExpression.Orderings)
{
AppendOrdering(ordering.Update(MakeNullable(ordering.Expression)));
}

var shaperRemapper = new ShaperRemappingExpressionVisitor(this, innerSelectExpression, indexOffset);
innerShaper = shaperRemapper.Visit(innerShaper);
selfIdentifier = shaperRemapper.Visit(selfIdentifier);
var indexOffset = _projection.Count;
foreach (var projection in innerSelectExpression.Projection)
{
AddToProjection(MakeNullable(projection.Expression));
}

return new RelationalCollectionShaperExpression(
collectionId, parentIdentifier, outerIdentifier, selfIdentifier, innerShaper, navigation, elementType);
foreach (var identifier in innerSelectExpression._identifier.Concat(innerSelectExpression._childIdentifiers))
{
var updatedColumn = MakeNullable(identifier);
_childIdentifiers.Add(updatedColumn);
AppendOrdering(new OrderingExpression(updatedColumn, ascending: true));
}

throw new InvalidOperationException("CollectionJoin: Unable to identify correlation predicate to convert to Left Join");
var shaperRemapper = new ShaperRemappingExpressionVisitor(this, innerSelectExpression, indexOffset);
innerShaper = shaperRemapper.Visit(innerShaper);
selfIdentifier = shaperRemapper.Visit(selfIdentifier);

return new RelationalCollectionShaperExpression(
collectionId, parentIdentifier, outerIdentifier, selfIdentifier, innerShaper, navigation, elementType);
}

private SqlExpression MakeNullable(SqlExpression sqlExpression)
private static SqlExpression MakeNullable(SqlExpression sqlExpression)
=> sqlExpression is ColumnExpression column ? column.MakeNullable() : sqlExpression;

private Expression GetIdentifierAccessor(IEnumerable<SqlExpression> identifyingProjection)
Expand Down Expand Up @@ -879,7 +884,9 @@ private object GetProjectionIndex(ProjectionBindingExpression projectionBindingE

private SqlExpression TryExtractJoinKey(SelectExpression selectExpression)
{
if (selectExpression.Predicate != null)
if (selectExpression.Limit == null
&& selectExpression.Offset == null
&& selectExpression.Predicate != null)
{
var joinPredicate = TryExtractJoinKey(selectExpression, selectExpression.Predicate, out var predicate);
selectExpression.Predicate = predicate;
Expand Down Expand Up @@ -959,6 +966,44 @@ private bool ContainsTableReference(TableExpressionBase table)
? ((SelectExpression)Tables[0]).ContainsTableReference(table)
: Tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table));

private class SelectExpressionCorrelationFindingExpressionVisitor : ExpressionVisitor
{
private readonly IReadOnlyList<TableExpressionBase> _tables;
private bool _containsOuterReference;

public SelectExpressionCorrelationFindingExpressionVisitor(IReadOnlyList<TableExpressionBase> tables)
{
_tables = tables;
}

public bool ContainsOuterReference(SelectExpression selectExpression)
{
_containsOuterReference = false;

Visit(selectExpression);

return _containsOuterReference;
}

public override Expression Visit(Expression expression)
{
if (_containsOuterReference)
{
return expression;
}

if (expression is ColumnExpression columnExpression
&& _tables.Contains(columnExpression.Table))
{
_containsOuterReference = true;

return expression;
}

return base.Visit(expression);
}
}

private enum JoinType
{
InnerJoin,
Expand All @@ -978,12 +1023,20 @@ private void AddJoin(
if (joinType == JoinType.InnerJoinLateral || joinType == JoinType.LeftJoinLateral)
{
joinPredicate = TryExtractJoinKey(innerSelectExpression);
// TODO: Make sure that innerSelectExpression does not contain any reference from this SelectExpression
if (joinPredicate != null)
{
AddJoin(joinType == JoinType.InnerJoinLateral ? JoinType.InnerJoin : JoinType.LeftJoin,
innerSelectExpression, transparentIdentifierType, joinPredicate);
return;
var containsOuterReference = new SelectExpressionCorrelationFindingExpressionVisitor(Tables)
.ContainsOuterReference(innerSelectExpression);
if (containsOuterReference)
{
innerSelectExpression.ApplyPredicate(joinPredicate);
}
else
{
AddJoin(joinType == JoinType.InnerJoinLateral ? JoinType.InnerJoin : JoinType.LeftJoin,
innerSelectExpression, transparentIdentifierType, joinPredicate);
return;
}
}
}

Expand Down
24 changes: 24 additions & 0 deletions test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4084,6 +4084,30 @@ public override Task Multiple_select_many_with_predicate(bool isAsync)
return base.Multiple_select_many_with_predicate(isAsync);
}

[ConditionalTheory(Skip = "Issue#14935")]
public override Task SelectMany_correlated_with_outer_1(bool isAsync)
{
return base.SelectMany_correlated_with_outer_1(isAsync);
}

[ConditionalTheory(Skip = "Issue#14935")]
public override Task SelectMany_correlated_with_outer_2(bool isAsync)
{
return base.SelectMany_correlated_with_outer_2(isAsync);
}

[ConditionalTheory(Skip = "Issue#14935")]
public override Task SelectMany_correlated_with_outer_3(bool isAsync)
{
return base.SelectMany_correlated_with_outer_3(isAsync);
}

[ConditionalTheory(Skip = "Issue#14935")]
public override Task SelectMany_correlated_with_outer_4(bool isAsync)
{
return base.SelectMany_correlated_with_outer_4(isAsync);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,30 @@ public override Task SelectMany_without_result_selector_collection_navigation_co
return base.SelectMany_without_result_selector_collection_navigation_composed(isAsync);
}

[ConditionalTheory(Skip = "Issue#16963")]
public override Task SelectMany_correlated_with_outer_1(bool isAsync)
{
return base.SelectMany_correlated_with_outer_1(isAsync);
}

[ConditionalTheory(Skip = "Issue#16963")]
public override Task SelectMany_correlated_with_outer_2(bool isAsync)
{
return base.SelectMany_correlated_with_outer_2(isAsync);
}

[ConditionalTheory(Skip = "Issue#16963")]
public override Task SelectMany_correlated_with_outer_3(bool isAsync)
{
return base.SelectMany_correlated_with_outer_3(isAsync);
}

[ConditionalTheory(Skip = "Issue#16963")]
public override Task SelectMany_correlated_with_outer_4(bool isAsync)
{
return base.SelectMany_correlated_with_outer_4(isAsync);
}

#endregion
}
}
14 changes: 7 additions & 7 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4893,7 +4893,7 @@ orderby w.IsAutomatic
});
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_inner_subquery_selector_references_outer_qsre(bool isAsync)
{
Expand All @@ -4919,7 +4919,7 @@ from o in gs.OfType<Officer>()
});
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_inner_subquery_predicate_references_outer_qsre(bool isAsync)
{
Expand All @@ -4945,7 +4945,7 @@ from o in gs.OfType<Officer>()
});
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_nested_inner_subquery_references_outer_qsre_one_level_up(bool isAsync)
{
Expand Down Expand Up @@ -4984,7 +4984,7 @@ from o in gs.OfType<Officer>()
});
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_nested_inner_subquery_references_outer_qsre_two_levels_up(bool isAsync)
{
Expand Down Expand Up @@ -5729,7 +5729,7 @@ public virtual Task Where_required_navigation_on_derived_type(bool isAsync)
lls => lls.Where(ll => ll is LocustCommander ? ((LocustCommander)ll).HighCommand.IsOperational : false));
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Outer_parameter_in_join_key(bool isAsync)
{
Expand All @@ -5748,7 +5748,7 @@ join g in gs on o.FullName equals g.FullName
elementAsserter: (e, a) => CollectionAsserter<string>(elementSorter: ee => ee)(e.Collection, a.Collection));
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Outer_parameter_in_join_key_inner_and_outer(bool isAsync)
{
Expand Down Expand Up @@ -5786,7 +5786,7 @@ join g in gs on o.FullName equals g.FullName into grouping
elementAsserter: (e, a) => CollectionAsserter<string>(elementSorter: ee => ee)(e.Collection, a.Collection));
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Outer_parameter_in_group_join_with_DefaultIfEmpty(bool isAsync)
{
Expand Down
Loading