Skip to content

Commit

Permalink
Query: Lift GroupByAggregate when correlation predicate matches
Browse files Browse the repository at this point in the history
Earlier, we didn't match on exact predicate

Also fix bug in LikeExpression.Equals

Resolves #27102
  • Loading branch information
smitpatel committed Jan 5, 2022
1 parent 1a29900 commit 4546961
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1676,6 +1676,10 @@ private SqlExpression CombineGroupByAggregateTerms(SelectExpression selectExpres
selector = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(predicate, selector) },
elseResult: null);
if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled))
{
selectExpression.UpdatePredicate(_groupingElementCorrelationalPredicate!);
}
}

if (selectExpression.IsDistinct)
Expand Down
4 changes: 3 additions & 1 deletion src/EFCore.Relational/Query/SqlExpressions/LikeExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ private bool Equals(LikeExpression likeExpression)
=> base.Equals(likeExpression)
&& Match.Equals(likeExpression.Match)
&& Pattern.Equals(likeExpression.Pattern)
&& EscapeChar?.Equals(likeExpression.EscapeChar) == true;
&& (EscapeChar == null
? likeExpression.EscapeChar == null
: EscapeChar.Equals(likeExpression.EscapeChar));

/// <inheritdoc />
public override int GetHashCode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor
var newOrderings = selectExpression._orderings.Select(Visit).ToList<OrderingExpression>();
var offset = (SqlExpression?)Visit(selectExpression.Offset);
var limit = (SqlExpression?)Visit(selectExpression.Limit);
var groupingCorrelationPredicate = (SqlExpression?)Visit(selectExpression._groupingCorrelationPredicate);

var newSelectExpression = new SelectExpression(
selectExpression.Alias, newProjections, newTables, newTableReferences, newGroupBy, newOrderings)
Expand All @@ -793,7 +794,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor
IsDistinct = selectExpression.IsDistinct,
Tags = selectExpression.Tags,
_usedAliases = selectExpression._usedAliases.ToHashSet(),
_projectionMapping = newProjectionMappings
_projectionMapping = newProjectionMappings,
_groupingCorrelationPredicate = groupingCorrelationPredicate
};

newSelectExpression._tptLeftJoinTables.AddRange(selectExpression._tptLeftJoinTables);
Expand Down Expand Up @@ -865,7 +867,9 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio
if (subquery.Limit == null
&& subquery.Offset == null
&& subquery._groupBy.Count == 0
&& subquery.Predicate != null)
&& subquery.Predicate != null
&& ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled)
|| subquery.Predicate.Equals(subquery._groupingCorrelationPredicate)))
{
var initialTableCounts = 0;
var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count);
Expand Down
19 changes: 18 additions & 1 deletion src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public sealed partial class SelectExpression : TableExpressionBase
private List<Expression> _clientProjections = new();
private readonly List<string?> _aliasForClientProjections = new();

private SqlExpression? _groupingCorrelationPredicate;
private CloningExpressionVisitor? _cloningExpressionVisitor;

private SelectExpression(
Expand Down Expand Up @@ -528,6 +529,7 @@ static void UpdateLimit(SelectExpression selectExpression)
|| shapedQueryExpression.ResultCardinality == ResultCardinality.SingleOrDefault:
{
var innerSelectExpression = (SelectExpression)shapedQueryExpression.QueryExpression;
innerSelectExpression._groupingCorrelationPredicate = null;
var innerShaperExpression = shapedQueryExpression.ShaperExpression;
if (innerSelectExpression._clientProjections.Count == 0)
{
Expand Down Expand Up @@ -599,6 +601,7 @@ static Expression RemoveConvert(Expression expression)
when shapedQueryExpression.ResultCardinality == ResultCardinality.Enumerable:
{
var innerSelectExpression = (SelectExpression)shapedQueryExpression.QueryExpression;
innerSelectExpression._groupingCorrelationPredicate = null;
if (_identifier.Count == 0
|| innerSelectExpression._identifier.Count == 0)
{
Expand Down Expand Up @@ -1144,6 +1147,11 @@ public void ApplyPredicate(SqlExpression sqlExpression)
}
}

internal void UpdatePredicate(SqlExpression predicate)
{
Predicate = predicate;
}

/// <summary>
/// Applies grouping from given key selector.
/// </summary>
Expand Down Expand Up @@ -1254,6 +1262,10 @@ public GroupByShaperExpression ApplyGrouping(
.Aggregate((l, r) => sqlExpressionFactory.AndAlso(l, r));
clonedSelectExpression._groupBy.Clear();
clonedSelectExpression.ApplyPredicate(correlationPredicate);
if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled))
{
clonedSelectExpression._groupingCorrelationPredicate = clonedSelectExpression.Predicate;
}

if (!_identifier.All(e => _groupBy.Contains(e.Column)))
{
Expand Down Expand Up @@ -3354,6 +3366,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)

Offset = (SqlExpression?)visitor.Visit(Offset);
Limit = (SqlExpression?)visitor.Visit(Limit);
_groupingCorrelationPredicate = (SqlExpression?)visitor.Visit(_groupingCorrelationPredicate);

var identifier = VisitList(_identifier.Select(e => e.Column).ToList(), inPlace: true, out _)
.Zip(_identifier, (a, b) => (a, b.Comparer))
Expand Down Expand Up @@ -3432,6 +3445,9 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
var limit = (SqlExpression?)visitor.Visit(Limit);
changed |= limit != Limit;

var groupingCorrelationPredicate = (SqlExpression?)visitor.Visit(_groupingCorrelationPredicate);
changed |= groupingCorrelationPredicate != _groupingCorrelationPredicate;

var identifier = VisitList(_identifier.Select(e => e.Column).ToList(), inPlace: false, out var identifierChanged);
changed |= identifierChanged;

Expand All @@ -3453,7 +3469,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
Limit = limit,
IsDistinct = IsDistinct,
Tags = Tags,
_usedAliases = _usedAliases
_usedAliases = _usedAliases,
_groupingCorrelationPredicate = groupingCorrelationPredicate
};

newSelectExpression._tptLeftJoinTables.AddRange(_tptLeftJoinTables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,22 @@ public virtual Task GroupBy_with_aggregate_through_navigation_property(bool asyn
elementSorter: e => e.max);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_with_aggregate_containing_complex_where(bool async)
{
return AssertQuery(
async,
ss => from o in ss.Set<Order>()
group o.OrderID by o.EmployeeID into tg
select new
{
tg.Key,
Max = ss.Set<Order>().Where(e => e.EmployeeID == tg.Max() * 6).Max(t => (int?)t.OrderID)
},
elementSorter: e => e.Key);
}

#endregion

#region GroupByAnonymousAggregate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1884,6 +1884,25 @@ FROM [Orders] AS [o]
GROUP BY [o].[EmployeeID]");
}

public override async Task GroupBy_with_aggregate_containing_complex_where(bool async)
{
await base.GroupBy_with_aggregate_containing_complex_where(async);

AssertSql(
@"SELECT [o].[EmployeeID] AS [Key], (
SELECT MAX([o0].[OrderID])
FROM [Orders] AS [o0]
WHERE (CAST([o0].[EmployeeID] AS bigint) = CAST(((
SELECT MAX([o1].[OrderID])
FROM [Orders] AS [o1]
WHERE ([o].[EmployeeID] = [o1].[EmployeeID]) OR ([o].[EmployeeID] IS NULL AND [o1].[EmployeeID] IS NULL)) * 6) AS bigint)) OR ([o0].[EmployeeID] IS NULL AND (
SELECT MAX([o1].[OrderID])
FROM [Orders] AS [o1]
WHERE ([o].[EmployeeID] = [o1].[EmployeeID]) OR ([o].[EmployeeID] IS NULL AND [o1].[EmployeeID] IS NULL)) IS NULL)) AS [Max]
FROM [Orders] AS [o]
GROUP BY [o].[EmployeeID]");
}

public override async Task GroupBy_Shadow(bool async)
{
await base.GroupBy_Shadow(async);
Expand Down

0 comments on commit 4546961

Please sign in to comment.