From 7fe751b62aca9ae5aacf1b2df921c3445e0b512b Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Thu, 27 Jun 2019 17:50:34 +0200 Subject: [PATCH] Fixed detection of differing includes in set operands --- .../SqlExpressions/SelectExpression.cs | 5 ++- .../NavigationExpandingVisitor_MethodCall.cs | 37 ++++++++++--------- .../NavigationExpansionReducingVisitor.cs | 4 +- .../Visitors/PendingIncludeFindingVisitor.cs | 18 +++++++-- 4 files changed, 39 insertions(+), 25 deletions(-) diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs index 549ea967a06..3fbc4d649cf 100644 --- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs @@ -565,7 +565,10 @@ static Expression UpdateEntityShaperEntityType(Expression shaperExpression, IEnt case EntityShaperExpression entityShaperExpression: return new EntityShaperExpression(newEntityType, entityShaperExpression.ValueBufferExpression, entityShaperExpression.Nullable); case UnaryExpression unary when unary.NodeType == ExpressionType.Convert: - return Convert(UpdateEntityShaperEntityType(unary.Operand, newEntityType), unary.Type); + var newShaperExpression = UpdateEntityShaperEntityType(unary.Operand, newEntityType); + return newShaperExpression.Type == unary.Type + ? newShaperExpression + : Convert(newShaperExpression, unary.Type); case UnaryExpression unary when unary.NodeType == ExpressionType.ConvertChecked: return ConvertChecked(UpdateEntityShaperEntityType(unary.Operand, newEntityType), unary.Type); default: diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs index 01c9eef6ede..ddc9d80c346 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs @@ -874,33 +874,34 @@ private Expression ProcessSetOperation(MethodCallExpression methodCallExpression var source2 = VisitSourceExpression(methodCallExpression.Arguments[1]); var preProcessResult2 = PreProcessTerminatingOperation(source2); - // Compare the include chains from each side to make sure they're identical. We don't allow set operations over - // operands with different include chains. - var current1 = preProcessResult1.state.PendingIncludeChain?.NavigationTreeNode; - var current2 = preProcessResult2.state.PendingIncludeChain?.NavigationTreeNode; - while (true) + // Extract the includes from each side and compare to make sure they're identical. + // We don't allow set operations over operands with different includes. + var pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(false); + pendingIncludeFindingVisitor.Visit(preProcessResult1.state.PendingSelector.Body); + var pendingIncludes1 = pendingIncludeFindingVisitor.PendingIncludes; + + pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(false); + pendingIncludeFindingVisitor.Visit(preProcessResult2.state.PendingSelector.Body); + var pendingIncludes2 = pendingIncludeFindingVisitor.PendingIncludes; + + if (pendingIncludes1.Count != pendingIncludes2.Count) { - if (current1 == null) - { - if (current2 == null) - { - break; - } - throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands); - } + throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands); + } - if (current2 == null) + foreach (var (i1, i2) in pendingIncludes1.Zip(pendingIncludes2, (i1, i2) => (i1, i2))) + { + if (i1.SourceMapping.RootEntityType != i2.SourceMapping.RootEntityType) { throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands); } - if (current1.FromMappings.Zip(current2.FromMappings, (m1, m2) => (m1, m2)) - .Any(t => !t.m1.SequenceEqual(t.m2))) + if (i1.NavTreeNode.Flatten() + .Zip(i2.NavTreeNode.Flatten(), (n1, n2) => (First: n1, Second: n2)) + .Any(nodes => nodes.First.Navigation != nodes.Second.Navigation)) { throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands); } - - (current1, current2) = (current1.Parent, current2.Parent); } // If the siblings are different types, one is derived from the other the set operation returns the less derived type. diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs index c0f601d449d..ec29f0556dd 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs @@ -139,8 +139,8 @@ protected override Expression VisitExtension(Expression extensionExpression) result = NavigationExpansionHelpers.AddNavigationJoin( result.source, result.parameter, - pendingIncludeNode.Value, - pendingIncludeNode.Key, + pendingIncludeNode.SourceMapping, + pendingIncludeNode.NavTreeNode, navigationExpansionExpression.State, new List(), include: true); diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs index ae4dd3e703b..5c3b0b7a039 100644 --- a/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs +++ b/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs @@ -10,7 +10,15 @@ namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors public class PendingIncludeFindingVisitor : ExpressionVisitor { - public virtual Dictionary PendingIncludes { get; } = new Dictionary(); + private bool _skipCollectionNavigations; + + public PendingIncludeFindingVisitor(bool skipCollectionNavigations = true) + { + _skipCollectionNavigations = skipCollectionNavigations; + } + + public virtual List<(NavigationTreeNode NavTreeNode, SourceMapping SourceMapping)> PendingIncludes { get; } = + new List<(NavigationTreeNode, SourceMapping)>(); protected override Expression VisitMember(MemberExpression memberExpression) => memberExpression; protected override Expression VisitInvocation(InvocationExpression invocationExpression) => invocationExpression; @@ -69,14 +77,16 @@ protected override Expression VisitExtension(Expression extensionExpression) private void FindPendingReferenceIncludes(NavigationTreeNode node, SourceMapping sourceMapping) { - if (node.Navigation != null && node.Navigation.IsCollection()) + if (_skipCollectionNavigations && node.Navigation != null && node.Navigation.IsCollection()) { return; } - if (node.Included == NavigationTreeNodeIncludeMode.ReferencePending && node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete) + if (node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete + && (node.Included == NavigationTreeNodeIncludeMode.ReferencePending + || !_skipCollectionNavigations && node.Included == NavigationTreeNodeIncludeMode.Collection)) { - PendingIncludes[node] = sourceMapping; + PendingIncludes.Add((node, sourceMapping)); } foreach (var child in node.Children)