From 79d5039b6ede3586050663d2f93577f1f81855d2 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Mon, 10 Jun 2019 15:23:02 -0700 Subject: [PATCH] Query: Encapsulate more logic about pushdown into SelectExpression --- ...yableMethodTranslatingExpressionVisitor.cs | 140 ++---------------- .../SqlExpressions/SelectExpression.cs | 129 +++++++++++++--- 2 files changed, 127 insertions(+), 142 deletions(-) diff --git a/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 67fb7a5b9a6..9332ddda9cd 100644 --- a/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -59,17 +59,11 @@ public override ShapedQueryExpression TranslateSubquery(Expression expression) protected override ShapedQueryExpression TranslateAll(ShapedQueryExpression source, LambdaExpression predicate) { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } - var translation = TranslateLambdaExpression(source, predicate); if (translation != null) { + var selectExpression = (SelectExpression)source.QueryExpression; selectExpression.ApplyPredicate(_sqlExpressionFactory.Not(translation)); selectExpression.ReplaceProjectionMapping(new Dictionary()); if (selectExpression.Limit == null @@ -113,11 +107,7 @@ protected override ShapedQueryExpression TranslateAny(ShapedQueryExpression sour protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression source, LambdaExpression selector, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } + selectExpression.PrepareForAggregate(); if (selector != null) { @@ -192,13 +182,7 @@ protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression source, LambdaExpression predicate) { var selectExpression = (SelectExpression)source.QueryExpression; - - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } + selectExpression.PrepareForAggregate(); if (predicate != null) { @@ -307,19 +291,6 @@ protected override ShapedQueryExpression TranslateJoin( LambdaExpression innerKeySelector, LambdaExpression resultSelector) { - // TODO: write a test which has distinct on outer so that we can verify pushdown - var innerSelectExpression = (SelectExpression)inner.QueryExpression; - if (innerSelectExpression.Orderings.Any() - || innerSelectExpression.Limit != null - || innerSelectExpression.Offset != null - || innerSelectExpression.IsDistinct - // TODO: Predicate can be lifted in inner join - || innerSelectExpression.Predicate != null - || innerSelectExpression.Tables.Count > 1) - { - innerSelectExpression.PushdownIntoSubQuery(); - } - var joinPredicate = CreateJoinPredicate(outer, outerKeySelector, inner, innerKeySelector); if (joinPredicate != null) { @@ -328,7 +299,7 @@ protected override ShapedQueryExpression TranslateJoin( resultSelector.Parameters[1].Type); ((SelectExpression)outer.QueryExpression).AddInnerJoin( - innerSelectExpression, joinPredicate, transparentIdentifierType); + (SelectExpression)inner.QueryExpression, joinPredicate, transparentIdentifierType); return TranslateResultSelectorForJoin( outer, @@ -343,25 +314,6 @@ protected override ShapedQueryExpression TranslateJoin( protected override ShapedQueryExpression TranslateLeftJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector) { - var outerSelectExpression = (SelectExpression)outer.QueryExpression; - if (outerSelectExpression.Limit != null - || outerSelectExpression.Offset != null - || outerSelectExpression.IsDistinct) - { - outerSelectExpression.PushdownIntoSubQuery(); - } - - var innerSelectExpression = (SelectExpression)inner.QueryExpression; - if (innerSelectExpression.Orderings.Any() - || innerSelectExpression.Limit != null - || innerSelectExpression.Offset != null - || innerSelectExpression.IsDistinct - || innerSelectExpression.Predicate != null - || innerSelectExpression.Tables.Count > 1) - { - innerSelectExpression.PushdownIntoSubQuery(); - } - var joinPredicate = CreateJoinPredicate(outer, outerKeySelector, inner, innerKeySelector); if (joinPredicate != null) { @@ -369,8 +321,8 @@ protected override ShapedQueryExpression TranslateLeftJoin(ShapedQueryExpression resultSelector.Parameters[0].Type, resultSelector.Parameters[1].Type); - outerSelectExpression.AddLeftJoin( - innerSelectExpression, joinPredicate, transparentIdentifierType); + ((SelectExpression)outer.QueryExpression).AddLeftJoin( + (SelectExpression)inner.QueryExpression, joinPredicate, transparentIdentifierType); return TranslateResultSelectorForJoin( outer, @@ -447,13 +399,6 @@ protected override ShapedQueryExpression TranslateLastOrDefault( } var selectExpression = (SelectExpression)source.QueryExpression; - - if (selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } - selectExpression.ReverseOrderings(); selectExpression.ApplyLimit(TranslateExpression(Expression.Constant(1))); @@ -468,13 +413,7 @@ protected override ShapedQueryExpression TranslateLastOrDefault( protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpression source, LambdaExpression predicate) { var selectExpression = (SelectExpression)source.QueryExpression; - - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } + selectExpression.PrepareForAggregate(); if (predicate != null) { @@ -498,11 +437,7 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression source, LambdaExpression selector, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } + selectExpression.PrepareForAggregate(); if (selector != null) { @@ -519,11 +454,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression source, LambdaExpression selector, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } + selectExpression.PrepareForAggregate(); if (selector != null) { @@ -600,19 +531,10 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s protected override ShapedQueryExpression TranslateOrderBy(ShapedQueryExpression source, LambdaExpression keySelector, bool ascending) { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } - var translation = TranslateLambdaExpression(source, keySelector); - if (translation != null) { - selectExpression.ApplyOrdering(new OrderingExpression(translation, ascending)); + ((SelectExpression)source.QueryExpression).ApplyOrdering(new OrderingExpression(translation, ascending)); return source; } @@ -632,7 +554,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s var selectExpression = (SelectExpression)source.QueryExpression; if (selectExpression.IsDistinct) { - selectExpression.PushdownIntoSubQuery(); + selectExpression.PushdownIntoSubquery(); } var newSelectorBody = ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body); @@ -669,31 +591,11 @@ protected override ShapedQueryExpression TranslateSelectMany(ShapedQueryExpressi { if (Visit(collectionSelectorBody) is ShapedQueryExpression inner) { - var outerSelectExpression = (SelectExpression)source.QueryExpression; - if (outerSelectExpression.Limit != null - || outerSelectExpression.Offset != null - || outerSelectExpression.IsDistinct - || outerSelectExpression.Predicate != null) - { - outerSelectExpression.PushdownIntoSubQuery(); - } - - var innerSelectExpression = (SelectExpression)inner.QueryExpression; - if (innerSelectExpression.Orderings.Any() - || innerSelectExpression.Limit != null - || innerSelectExpression.Offset != null - || innerSelectExpression.IsDistinct - || innerSelectExpression.Predicate != null) - { - innerSelectExpression.PushdownIntoSubQuery(); - } - var transparentIdentifierType = CreateTransparentIdentifierType( resultSelector.Parameters[0].Type, resultSelector.Parameters[1].Type); - - outerSelectExpression.AddCrossJoin( - innerSelectExpression, transparentIdentifierType); + ((SelectExpression)source.QueryExpression).AddCrossJoin( + (SelectExpression)inner.QueryExpression, transparentIdentifierType); return TranslateResultSelectorForJoin( source, @@ -775,11 +677,7 @@ protected override ShapedQueryExpression TranslateSkip(ShapedQueryExpression sou protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression source, LambdaExpression selector, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } + selectExpression.PrepareForAggregate(); if (selector != null) { @@ -825,7 +723,6 @@ protected override ShapedQueryExpression TranslateTake(ShapedQueryExpression sou protected override ShapedQueryExpression TranslateThenBy(ShapedQueryExpression source, LambdaExpression keySelector, bool ascending) { var translation = TranslateLambdaExpression(source, keySelector); - if (translation != null) { ((SelectExpression)source.QueryExpression).AppendOrdering(new OrderingExpression(translation, ascending)); @@ -840,17 +737,10 @@ protected override ShapedQueryExpression TranslateThenBy(ShapedQueryExpression s protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression source, LambdaExpression predicate) { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.Limit != null - || selectExpression.Offset != null) - { - selectExpression.PushdownIntoSubQuery(); - } - var translation = TranslateLambdaExpression(source, predicate); if (translation != null) { - selectExpression.ApplyPredicate(translation); + ((SelectExpression)source.QueryExpression).ApplyPredicate(translation); return source; } diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs index efb6bb388e9..5a6ce7ff6ab 100644 --- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs @@ -205,6 +205,16 @@ public IDictionary AddToProjection(EntityProjectionExpression en return dictionary; } + public void PrepareForAggregate() + { + if (IsDistinct + || Limit != null + || Offset != null) + { + PushdownIntoSubquery(); + } + } + public void ApplyPredicate(SqlExpression expression) { if (expression is SqlConstantExpression sqlConstant @@ -213,6 +223,13 @@ public void ApplyPredicate(SqlExpression expression) return; } + if (Limit != null + || Offset != null) + { + var mappings = PushdownIntoSubquery(); + expression = new SqlRemappingVisitor(mappings).Remap(expression); + } + if (Predicate == null) { Predicate = expression; @@ -233,6 +250,15 @@ public void ApplyPredicate(SqlExpression expression) public void ApplyOrdering(OrderingExpression orderingExpression) { + if (IsDistinct + || Limit != null + || Offset != null) + { + orderingExpression = orderingExpression.Update( + new SqlRemappingVisitor(PushdownIntoSubquery()) + .Remap(orderingExpression.Expression)); + } + _orderings.Clear(); _orderings.Add(orderingExpression); } @@ -249,7 +275,7 @@ public void ApplyLimit(SqlExpression sqlExpression) { if (Limit != null) { - PushdownIntoSubQuery(); + PushdownIntoSubquery(); } Limit = sqlExpression; @@ -260,7 +286,7 @@ public void ApplyOffset(SqlExpression sqlExpression) if (Limit != null || Offset != null) { - PushdownIntoSubQuery(); + PushdownIntoSubquery(); } Offset = sqlExpression; @@ -268,6 +294,12 @@ public void ApplyOffset(SqlExpression sqlExpression) public void ReverseOrderings() { + if (Limit != null + || Offset != null) + { + PushdownIntoSubquery(); + } + var existingOrdering = _orderings.ToArray(); _orderings.Clear(); @@ -286,7 +318,7 @@ public void ApplyDistinct() if (Limit != null || Offset != null) { - PushdownIntoSubQuery(); + PushdownIntoSubquery(); } IsDistinct = true; @@ -298,7 +330,7 @@ public void ClearOrdering() _orderings.Clear(); } - public SelectExpression PushdownIntoSubQuery() + public IDictionary PushdownIntoSubquery() { var subquery = new SelectExpression("t", new List(), _tables.ToList(), _orderings.ToList()) { @@ -393,7 +425,7 @@ public SelectExpression PushdownIntoSubQuery() _tables.Clear(); _tables.Add(subquery); - return subquery; + return projectionMap; } private static bool IsNullableProjection(ProjectionExpression projection) @@ -445,8 +477,7 @@ public RelationalCollectionShaperExpression ApplyCollectionJoin(int collectionId || Limit != null || Offset != null) { - var subquery = PushdownIntoSubQuery(); - outer = LiftFromSubquery(subquery, outer); + outer = new SqlRemappingVisitor(PushdownIntoSubquery()).Remap(outer); } if (innerSelectExpression.Offset != null @@ -455,8 +486,8 @@ public RelationalCollectionShaperExpression ApplyCollectionJoin(int collectionId || innerSelectExpression.Predicate != null || innerSelectExpression.Tables.Count > 1) { - var subquery = innerSelectExpression.PushdownIntoSubQuery(); - inner = LiftFromSubquery(subquery, inner); + inner = new SqlRemappingVisitor(innerSelectExpression.PushdownIntoSubquery()) + .Remap(inner); } var leftJoinExpression = new LeftJoinExpression(innerSelectExpression.Tables.Single(), @@ -608,15 +639,21 @@ private static bool ContainsTableReference(SelectExpression selectExpression, Ta return selectExpression.Tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table)); } - private ColumnExpression LiftFromSubquery(SelectExpression subquery, SqlExpression column) - { - var subqueryProjection = subquery._projection.Single(pe => pe.Expression.Equals(column)); - - return new ColumnExpression(subqueryProjection, subquery, IsNullableProjection(subqueryProjection)); - } - public void AddInnerJoin(SelectExpression innerSelectExpression, SqlExpression joinPredicate, Type transparentIdentifierType) { + // TODO: write a test which has distinct on outer so that we can verify pushdown + if (innerSelectExpression.Orderings.Any() + || innerSelectExpression.Limit != null + || innerSelectExpression.Offset != null + || innerSelectExpression.IsDistinct + // TODO: Predicate can be lifted in inner join + || innerSelectExpression.Predicate != null + || innerSelectExpression.Tables.Count > 1) + { + joinPredicate = new SqlRemappingVisitor(innerSelectExpression.PushdownIntoSubquery()) + .Remap(joinPredicate); + } + _identifyingProjection.AddRange(innerSelectExpression._identifyingProjection); var joinTable = new InnerJoinExpression(innerSelectExpression.Tables.Single(), joinPredicate); _tables.Add(joinTable); @@ -639,6 +676,25 @@ public void AddInnerJoin(SelectExpression innerSelectExpression, SqlExpression j public void AddLeftJoin(SelectExpression innerSelectExpression, SqlExpression joinPredicate, Type transparentIdentifierType) { + if (Limit != null + || Offset != null + || IsDistinct) + { + joinPredicate = new SqlRemappingVisitor(PushdownIntoSubquery()) + .Remap(joinPredicate); + } + + if (innerSelectExpression.Orderings.Any() + || innerSelectExpression.Limit != null + || innerSelectExpression.Offset != null + || innerSelectExpression.IsDistinct + || innerSelectExpression.Predicate != null + || innerSelectExpression.Tables.Count > 1) + { + joinPredicate = new SqlRemappingVisitor(innerSelectExpression.PushdownIntoSubquery()) + .Remap(joinPredicate); + } + var joinTable = new LeftJoinExpression(innerSelectExpression.Tables.Single(), joinPredicate); _tables.Add(joinTable); @@ -670,6 +726,23 @@ public void AddLeftJoin(SelectExpression innerSelectExpression, SqlExpression jo public void AddCrossJoin(SelectExpression innerSelectExpression, Type transparentIdentifierType) { + if (Limit != null + || Offset != null + || IsDistinct + || Predicate != null) + { + PushdownIntoSubquery(); + } + + if (innerSelectExpression.Orderings.Any() + || innerSelectExpression.Limit != null + || innerSelectExpression.Offset != null + || innerSelectExpression.IsDistinct + || innerSelectExpression.Predicate != null) + { + innerSelectExpression.PushdownIntoSubquery(); + } + _identifyingProjection.AddRange(innerSelectExpression._identifyingProjection); var joinTable = new CrossJoinExpression(innerSelectExpression.Tables.Single()); _tables.Add(joinTable); @@ -690,6 +763,29 @@ public void AddCrossJoin(SelectExpression innerSelectExpression, Type transparen _projectionMapping = projectionMapping; } + private class SqlRemappingVisitor : ExpressionVisitor + { + private readonly IDictionary _mappings; + + public SqlRemappingVisitor(IDictionary mappings) + { + _mappings = mappings; + } + + public SqlExpression Remap(SqlExpression sqlExpression) => (SqlExpression)Visit(sqlExpression); + + public override Expression Visit(Expression expression) + { + if (expression is SqlExpression sqlExpression + && _mappings.TryGetValue(sqlExpression, out var outer)) + { + return outer; + } + + return base.Visit(expression); + } + } + protected override Expression VisitChildren(ExpressionVisitor visitor) { var changed = false; @@ -876,7 +972,6 @@ public override int GetHashCode() return hashCode; } } - public override void Print(ExpressionPrinter expressionPrinter) { expressionPrinter.StringBuilder.AppendLine("Projection Mapping:");