From 686d56d12f2f87441e369caae7d7b6509780051f Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 25 Jun 2019 14:02:00 +0200 Subject: [PATCH] Include/nav support for set operations Continues #6812 Fixes #13196 Fixes #16065 Fixes #16165 --- ...ShapedQueryOptimizingExpressionVisitors.cs | 2 +- .../SqlExpressions/SelectExpression.cs | 8 +- src/EFCore/Properties/CoreStrings.Designer.cs | 6 + src/EFCore/Properties/CoreStrings.resx | 3 + .../NavigationExpandingVisitor_MethodCall.cs | 154 ++++++++++++------ .../NavigationExpansionReducingVisitor.cs | 4 +- .../Visitors/PendingIncludeFindingVisitor.cs | 18 +- .../SimpleQueryCosmosTest.SetOperations.cs | 4 + .../SimpleQueryTestBase.SetOperations.cs | 59 +++++++ .../SimpleQuerySqlServerTest.SetOperations.cs | 38 +++++ 10 files changed, 239 insertions(+), 57 deletions(-) diff --git a/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs b/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs index 44aae76545c..ab750c726a8 100644 --- a/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs +++ b/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs @@ -51,7 +51,7 @@ protected override Expression VisitExtension(Expression extensionExpression) var collectionId = _collectionId++; var selectExpression = (SelectExpression)collectionShaperExpression.Projection.QueryExpression; // Do pushdown beforehand so it updates all pending collections first - if (selectExpression.IsDistinct || selectExpression.Limit != null || selectExpression.Offset != null) + if (selectExpression.IsDistinct || selectExpression.Limit != null || selectExpression.Offset != null || selectExpression.IsSetOperation) { selectExpression.PushdownIntoSubquery(); } diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs index e170d8531a0..a555efe1ede 100644 --- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs @@ -842,10 +842,12 @@ private SqlBinaryExpression ValidateKeyComparison(SelectExpression inner, SqlBin return null; } + // We treat a set operation as a transparent wrapper over its left operand (the ColumnExpression projection mappings + // found on a set operation SelectExpression are actually those of its left operand). private bool ContainsTableReference(TableExpressionBase table) - { - return _tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table)); - } + => IsSetOperation + ? ((SelectExpression)Tables[0]).ContainsTableReference(table) + : Tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table)); public void AddInnerJoin(SelectExpression innerSelectExpression, SqlExpression joinPredicate, Type transparentIdentifierType) { diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index a6a242d418d..52c1179dff9 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -2148,6 +2148,12 @@ public static string UnableToDiscriminate([CanBeNull] object entityType, [CanBeN GetString("UnableToDiscriminate", nameof(entityType), nameof(discriminator)), entityType, discriminator); + /// + /// When performing a set operation, both operands must have the same Include operations. + /// + public static string SetOperationWithDifferentIncludesInOperands + => GetString("SetOperationWithDifferentIncludesInOperands"); + private static string GetString(string name, params string[] formatterNames) { var value = _resourceManager.GetString(name); diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index cd4066e6ad0..1e2ccbe89ab 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -1186,4 +1186,7 @@ Unable to materialize entity of type '{entityType}'. No discriminators matched '{discriminator}'. + + When performing a set operation, both operands must have the same Include operations. + diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs index fa633a0e870..202d3153f04 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs @@ -7,6 +7,7 @@ using System.Linq.Expressions; using System.Reflection; using System.Xml; +using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; @@ -35,71 +36,85 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp : methodCallExpression.Update(methodCallExpression.Object, new[] { newSource, methodCallExpression.Arguments[1] }); } - switch (methodCallExpression.Method.Name) + if (methodCallExpression.Method.DeclaringType == typeof(Queryable) + || methodCallExpression.Method.DeclaringType == typeof(QueryableExtensions) + || methodCallExpression.Method.DeclaringType == typeof(Enumerable) + || methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)) { - case nameof(Queryable.Where): - return ProcessWhere(methodCallExpression); + switch (methodCallExpression.Method.Name) + { + case nameof(Queryable.Where): + return ProcessWhere(methodCallExpression); - case nameof(Queryable.Select): - return ProcessSelect(methodCallExpression); + case nameof(Queryable.Select): + return ProcessSelect(methodCallExpression); - case nameof(Queryable.OrderBy): - case nameof(Queryable.OrderByDescending): - return ProcessOrderBy(methodCallExpression); + case nameof(Queryable.OrderBy): + case nameof(Queryable.OrderByDescending): + return ProcessOrderBy(methodCallExpression); - case nameof(Queryable.ThenBy): - case nameof(Queryable.ThenByDescending): - return ProcessThenByBy(methodCallExpression); + case nameof(Queryable.ThenBy): + case nameof(Queryable.ThenByDescending): + return ProcessThenByBy(methodCallExpression); - case nameof(Queryable.Join): - return ProcessJoin(methodCallExpression); + case nameof(Queryable.Join): + return ProcessJoin(methodCallExpression); - case nameof(Queryable.GroupJoin): - return ProcessGroupJoin(methodCallExpression); + case nameof(Queryable.GroupJoin): + return ProcessGroupJoin(methodCallExpression); - case nameof(Queryable.SelectMany): - return ProcessSelectMany(methodCallExpression); + case nameof(Queryable.SelectMany): + return ProcessSelectMany(methodCallExpression); - case nameof(Queryable.All): - return ProcessAll(methodCallExpression); + case nameof(Queryable.All): + return ProcessAll(methodCallExpression); - case nameof(Queryable.Any): - case nameof(Queryable.Count): - case nameof(Queryable.LongCount): - return ProcessAnyCountLongCount(methodCallExpression); + case nameof(Queryable.Any): + case nameof(Queryable.Count): + case nameof(Queryable.LongCount): + return ProcessAnyCountLongCount(methodCallExpression); - case nameof(Queryable.Average): - case nameof(Queryable.Sum): - case nameof(Queryable.Min): - case nameof(Queryable.Max): - return ProcessAverageSumMinMax(methodCallExpression); + case nameof(Queryable.Average): + case nameof(Queryable.Sum): + case nameof(Queryable.Min): + case nameof(Queryable.Max): + return ProcessAverageSumMinMax(methodCallExpression); - case nameof(Queryable.Distinct): - return ProcessDistinct(methodCallExpression); + case nameof(Queryable.Distinct): + return ProcessDistinct(methodCallExpression); - case nameof(Queryable.DefaultIfEmpty): - return ProcessDefaultIfEmpty(methodCallExpression); + case nameof(Queryable.DefaultIfEmpty): + return ProcessDefaultIfEmpty(methodCallExpression); - case nameof(Queryable.First): - case nameof(Queryable.FirstOrDefault): - case nameof(Queryable.Single): - case nameof(Queryable.SingleOrDefault): - return ProcessCardinalityReducingOperation(methodCallExpression); + case nameof(Queryable.First): + case nameof(Queryable.FirstOrDefault): + case nameof(Queryable.Single): + case nameof(Queryable.SingleOrDefault): + return ProcessCardinalityReducingOperation(methodCallExpression); - case nameof(Queryable.OfType): - return ProcessOfType(methodCallExpression); + case nameof(Queryable.OfType): + return ProcessOfType(methodCallExpression); - case nameof(Queryable.Skip): - case nameof(Queryable.Take): - return ProcessSkipTake(methodCallExpression); + case nameof(Queryable.Skip): + case nameof(Queryable.Take): + return ProcessSkipTake(methodCallExpression); - case "Include": - case "ThenInclude": - return ProcessInclude(methodCallExpression); + case nameof(Queryable.Union): + case nameof(Queryable.Concat): + case nameof(Queryable.Intersect): + case nameof(Queryable.Except): + return ProcessSetOperation(methodCallExpression); - default: - return ProcessUnknownMethod(methodCallExpression); + case "Include": + case "ThenInclude": + return ProcessInclude(methodCallExpression); + + default: + return ProcessUnknownMethod(methodCallExpression); + } } + + return ProcessUnknownMethod(methodCallExpression); } private Expression ProcessUnknownMethod(MethodCallExpression methodCallExpression) @@ -820,6 +835,51 @@ private Expression ProcessSkipTake(MethodCallExpression methodCallExpression) return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type); } + private Expression ProcessSetOperation(MethodCallExpression methodCallExpression) + { + // TODO: We shouldn't terminate if both sides are identical, #16246 + + var source1 = VisitSourceExpression(methodCallExpression.Arguments[0]); + var preProcessResult1 = PreProcessTerminatingOperation(source1); + + var source2 = VisitSourceExpression(methodCallExpression.Arguments[1]); + var preProcessResult2 = PreProcessTerminatingOperation(source2); + + // Extract the includes from each side and compare to make sure they're identical. + // We don't allow set operations over operands with different includes. + var pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(skipCollectionNavigations: false); + pendingIncludeFindingVisitor.Visit(preProcessResult1.state.PendingSelector.Body); + var pendingIncludes1 = pendingIncludeFindingVisitor.PendingIncludes; + + pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(skipCollectionNavigations: false); + pendingIncludeFindingVisitor.Visit(preProcessResult2.state.PendingSelector.Body); + var pendingIncludes2 = pendingIncludeFindingVisitor.PendingIncludes; + + if (pendingIncludes1.Count != pendingIncludes2.Count) + { + throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands); + } + + foreach (var (i1, i2) in pendingIncludes1.Zip(pendingIncludes2, (i1, i2) => (i1, i2))) + { + if (i1.SourceMapping.RootEntityType != i2.SourceMapping.RootEntityType + || i1.NavTreeNode.Navigation != i2.NavTreeNode.Navigation) + { + throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands); + } + } + + // If the siblings are different types, one is derived from the other the set operation returns the less derived type. + // Find that. + var clrType1 = preProcessResult1.state.CurrentParameter.Type; + var clrType2 = preProcessResult2.state.CurrentParameter.Type; + var parentState = clrType1.IsAssignableFrom(clrType2) ? preProcessResult1.state : preProcessResult2.state; + + var rewritten = methodCallExpression.Update(null, new[] { preProcessResult1.source, preProcessResult2.source }); + + return new NavigationExpansionExpression(rewritten, parentState, methodCallExpression.Type); + } + private (Expression source, NavigationExpansionExpressionState state) PreProcessTerminatingOperation(NavigationExpansionExpression source) { var applyOrderingsResult = ApplyPendingOrderings(source.Operand, source.State); diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs index 828d4ccc9c9..0648f4d476f 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs @@ -120,8 +120,8 @@ protected override Expression VisitExtension(Expression extensionExpression) result = NavigationExpansionHelpers.AddNavigationJoin( result.source, result.parameter, - pendingIncludeNode.Value, - pendingIncludeNode.Key, + pendingIncludeNode.SourceMapping, + pendingIncludeNode.NavTreeNode, navigationExpansionExpression.State, new List(), include: true); diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs index ffcbde22396..963432490d5 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs @@ -10,7 +10,15 @@ namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors public class PendingIncludeFindingVisitor : ExpressionVisitor { - public virtual Dictionary PendingIncludes { get; } = new Dictionary(); + private bool _skipCollectionNavigations; + + public PendingIncludeFindingVisitor(bool skipCollectionNavigations = true) + { + _skipCollectionNavigations = skipCollectionNavigations; + } + + public virtual List<(NavigationTreeNode NavTreeNode, SourceMapping SourceMapping)> PendingIncludes { get; } = + new List<(NavigationTreeNode, SourceMapping)>(); protected override Expression VisitMember(MemberExpression memberExpression) { @@ -81,14 +89,16 @@ protected override Expression VisitExtension(Expression extensionExpression) private void FindPendingReferenceIncludes(NavigationTreeNode node, SourceMapping sourceMapping) { - if (node.Navigation != null && node.Navigation.IsCollection()) + if (_skipCollectionNavigations && node.Navigation != null && node.Navigation.IsCollection()) { return; } - if (node.Included == NavigationTreeNodeIncludeMode.ReferencePending && node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete) + if (node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete + && (node.Included == NavigationTreeNodeIncludeMode.ReferencePending + || !_skipCollectionNavigations && node.Included == NavigationTreeNodeIncludeMode.Collection)) { - PendingIncludes[node] = sourceMapping; + PendingIncludes.Add((node, sourceMapping)); } foreach (var child in node.Children) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.SetOperations.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.SetOperations.cs index 2b8a8f0cf36..2f9fd79f91f 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.SetOperations.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.SetOperations.cs @@ -33,7 +33,11 @@ public override void Union_non_entity(bool isAsync) {} public override Task Union_with_anonymous_type_projection(bool isAsync) => Task.CompletedTask; public override Task Select_Union_unrelated(bool isAsync) => Task.CompletedTask; public override Task Select_Union_different_fields_in_anonymous_with_subquery(bool isAsync) => Task.CompletedTask; + public override Task Union_Include(bool isAsync) => Task.CompletedTask; + public override Task Include_Union(bool isAsync) => Task.CompletedTask; public override Task Select_Except_reference_projection(bool isAsync) => Task.CompletedTask; + public override void Include_Union_only_on_one_side_throws() {} + public override void Include_Union_different_includes_throws() {} public override Task SubSelect_Union(bool isAsync) => Task.CompletedTask; public override Task Client_eval_Union_FirstOrDefault(bool isAsync) => Task.CompletedTask; } diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.SetOperations.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.SetOperations.cs index 241d2c10f43..caf56aa4569 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.SetOperations.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.SetOperations.cs @@ -250,6 +250,26 @@ public virtual Task Select_Union_different_fields_in_anonymous_with_subquery(boo .Where(x => x.Foo == "Berlin"), entryCount: 1); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Union_Include(bool isAsync) + => AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Union(cs.Where(c => c.City == "London")) + .Include(c => c.Orders), + entryCount: 59); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Include_Union(bool isAsync) + => AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Include(c => c.Orders) + .Union(cs + .Where(c => c.City == "London") + .Include(c => c.Orders)), + entryCount: 59); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Select_Except_reference_projection(bool isAsync) @@ -260,6 +280,45 @@ public virtual Task Select_Except_reference_projection(bool isAsync) .Select(o => o.Customer)), entryCount: 88); + [ConditionalFact] + public virtual void Include_Union_only_on_one_side_throws() + { + using (var ctx = CreateContext()) + { + Assert.Throws(() => + ctx.Customers + .Where(c => c.City == "Berlin") + .Include(c => c.Orders) + .Union(ctx.Customers.Where(c => c.City == "London")) + .ToList()); + + Assert.Throws(() => + ctx.Customers + .Where(c => c.City == "Berlin") + .Union(ctx.Customers + .Where(c => c.City == "London") + .Include(c => c.Orders)) + .ToList()); + } + } + + [ConditionalFact] + public virtual void Include_Union_different_includes_throws() + { + using (var ctx = CreateContext()) + { + Assert.Throws(() => + ctx.Customers + .Where(c => c.City == "Berlin") + .Include(c => c.Orders) + .Union(ctx.Customers + .Where(c => c.City == "London") + .Include(c => c.Orders) + .ThenInclude(o => o.OrderDetails)) + .ToList()); + } + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task SubSelect_Union(bool isAsync) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.SetOperations.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.SetOperations.cs index b7e9b957e81..3bdf17787a8 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.SetOperations.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.SetOperations.cs @@ -264,6 +264,44 @@ OFFSET @__p_0 ROWS FETCH NEXT @__p_1 ROWS ONLY ORDER BY [t0].[Foo]"); } + public override async Task Union_Include(bool isAsync) + { + await base.Union_Include(isAsync); + + AssertSql( + @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM ( + SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL + UNION + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL +) AS [t] +LEFT JOIN [Orders] AS [o] ON [t].[CustomerID] = [o].[CustomerID] +ORDER BY [t].[CustomerID], [o].[OrderID]"); + } + + public override async Task Include_Union(bool isAsync) + { + await base.Include_Union(isAsync); + + AssertSql( + @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM ( + SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL + UNION + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL +) AS [t] +LEFT JOIN [Orders] AS [o] ON [t].[CustomerID] = [o].[CustomerID] +ORDER BY [t].[CustomerID], [o].[OrderID]"); + } + public override async Task Select_Except_reference_projection(bool isAsync) { await base.Select_Except_reference_projection(isAsync);