Skip to content

Commit

Permalink
Include/nav support for set operations
Browse files Browse the repository at this point in the history
Continues #6812
Fixes #13196
Fixes #16065
Fixes #16165
  • Loading branch information
roji committed Jul 1, 2019
1 parent 8efc3ad commit 686d56d
Show file tree
Hide file tree
Showing 10 changed files with 239 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
var collectionId = _collectionId++;
var selectExpression = (SelectExpression)collectionShaperExpression.Projection.QueryExpression;
// Do pushdown beforehand so it updates all pending collections first
if (selectExpression.IsDistinct || selectExpression.Limit != null || selectExpression.Offset != null)
if (selectExpression.IsDistinct || selectExpression.Limit != null || selectExpression.Offset != null || selectExpression.IsSetOperation)
{
selectExpression.PushdownIntoSubquery();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -842,10 +842,12 @@ private SqlBinaryExpression ValidateKeyComparison(SelectExpression inner, SqlBin
return null;
}

// We treat a set operation as a transparent wrapper over its left operand (the ColumnExpression projection mappings
// found on a set operation SelectExpression are actually those of its left operand).
private bool ContainsTableReference(TableExpressionBase table)
{
return _tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table));
}
=> IsSetOperation
? ((SelectExpression)Tables[0]).ContainsTableReference(table)
: Tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table));

public void AddInnerJoin(SelectExpression innerSelectExpression, SqlExpression joinPredicate, Type transparentIdentifierType)
{
Expand Down
6 changes: 6 additions & 0 deletions src/EFCore/Properties/CoreStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/EFCore/Properties/CoreStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -1186,4 +1186,7 @@
<data name="UnableToDiscriminate" xml:space="preserve">
<value>Unable to materialize entity of type '{entityType}'. No discriminators matched '{discriminator}'.</value>
</data>
<data name="SetOperationWithDifferentIncludesInOperands" xml:space="preserve">
<value>When performing a set operation, both operands must have the same Include operations.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq.Expressions;
using System.Reflection;
using System.Xml;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
Expand Down Expand Up @@ -35,71 +36,85 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
: methodCallExpression.Update(methodCallExpression.Object, new[] { newSource, methodCallExpression.Arguments[1] });
}

switch (methodCallExpression.Method.Name)
if (methodCallExpression.Method.DeclaringType == typeof(Queryable)
|| methodCallExpression.Method.DeclaringType == typeof(QueryableExtensions)
|| methodCallExpression.Method.DeclaringType == typeof(Enumerable)
|| methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions))
{
case nameof(Queryable.Where):
return ProcessWhere(methodCallExpression);
switch (methodCallExpression.Method.Name)
{
case nameof(Queryable.Where):
return ProcessWhere(methodCallExpression);

case nameof(Queryable.Select):
return ProcessSelect(methodCallExpression);
case nameof(Queryable.Select):
return ProcessSelect(methodCallExpression);

case nameof(Queryable.OrderBy):
case nameof(Queryable.OrderByDescending):
return ProcessOrderBy(methodCallExpression);
case nameof(Queryable.OrderBy):
case nameof(Queryable.OrderByDescending):
return ProcessOrderBy(methodCallExpression);

case nameof(Queryable.ThenBy):
case nameof(Queryable.ThenByDescending):
return ProcessThenByBy(methodCallExpression);
case nameof(Queryable.ThenBy):
case nameof(Queryable.ThenByDescending):
return ProcessThenByBy(methodCallExpression);

case nameof(Queryable.Join):
return ProcessJoin(methodCallExpression);
case nameof(Queryable.Join):
return ProcessJoin(methodCallExpression);

case nameof(Queryable.GroupJoin):
return ProcessGroupJoin(methodCallExpression);
case nameof(Queryable.GroupJoin):
return ProcessGroupJoin(methodCallExpression);

case nameof(Queryable.SelectMany):
return ProcessSelectMany(methodCallExpression);
case nameof(Queryable.SelectMany):
return ProcessSelectMany(methodCallExpression);

case nameof(Queryable.All):
return ProcessAll(methodCallExpression);
case nameof(Queryable.All):
return ProcessAll(methodCallExpression);

case nameof(Queryable.Any):
case nameof(Queryable.Count):
case nameof(Queryable.LongCount):
return ProcessAnyCountLongCount(methodCallExpression);
case nameof(Queryable.Any):
case nameof(Queryable.Count):
case nameof(Queryable.LongCount):
return ProcessAnyCountLongCount(methodCallExpression);

case nameof(Queryable.Average):
case nameof(Queryable.Sum):
case nameof(Queryable.Min):
case nameof(Queryable.Max):
return ProcessAverageSumMinMax(methodCallExpression);
case nameof(Queryable.Average):
case nameof(Queryable.Sum):
case nameof(Queryable.Min):
case nameof(Queryable.Max):
return ProcessAverageSumMinMax(methodCallExpression);

case nameof(Queryable.Distinct):
return ProcessDistinct(methodCallExpression);
case nameof(Queryable.Distinct):
return ProcessDistinct(methodCallExpression);

case nameof(Queryable.DefaultIfEmpty):
return ProcessDefaultIfEmpty(methodCallExpression);
case nameof(Queryable.DefaultIfEmpty):
return ProcessDefaultIfEmpty(methodCallExpression);

case nameof(Queryable.First):
case nameof(Queryable.FirstOrDefault):
case nameof(Queryable.Single):
case nameof(Queryable.SingleOrDefault):
return ProcessCardinalityReducingOperation(methodCallExpression);
case nameof(Queryable.First):
case nameof(Queryable.FirstOrDefault):
case nameof(Queryable.Single):
case nameof(Queryable.SingleOrDefault):
return ProcessCardinalityReducingOperation(methodCallExpression);

case nameof(Queryable.OfType):
return ProcessOfType(methodCallExpression);
case nameof(Queryable.OfType):
return ProcessOfType(methodCallExpression);

case nameof(Queryable.Skip):
case nameof(Queryable.Take):
return ProcessSkipTake(methodCallExpression);
case nameof(Queryable.Skip):
case nameof(Queryable.Take):
return ProcessSkipTake(methodCallExpression);

case "Include":
case "ThenInclude":
return ProcessInclude(methodCallExpression);
case nameof(Queryable.Union):
case nameof(Queryable.Concat):
case nameof(Queryable.Intersect):
case nameof(Queryable.Except):
return ProcessSetOperation(methodCallExpression);

default:
return ProcessUnknownMethod(methodCallExpression);
case "Include":
case "ThenInclude":
return ProcessInclude(methodCallExpression);

default:
return ProcessUnknownMethod(methodCallExpression);
}
}

return ProcessUnknownMethod(methodCallExpression);
}

private Expression ProcessUnknownMethod(MethodCallExpression methodCallExpression)
Expand Down Expand Up @@ -820,6 +835,51 @@ private Expression ProcessSkipTake(MethodCallExpression methodCallExpression)
return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type);
}

private Expression ProcessSetOperation(MethodCallExpression methodCallExpression)
{
// TODO: We shouldn't terminate if both sides are identical, #16246

var source1 = VisitSourceExpression(methodCallExpression.Arguments[0]);
var preProcessResult1 = PreProcessTerminatingOperation(source1);

var source2 = VisitSourceExpression(methodCallExpression.Arguments[1]);
var preProcessResult2 = PreProcessTerminatingOperation(source2);

// Extract the includes from each side and compare to make sure they're identical.
// We don't allow set operations over operands with different includes.
var pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(skipCollectionNavigations: false);
pendingIncludeFindingVisitor.Visit(preProcessResult1.state.PendingSelector.Body);
var pendingIncludes1 = pendingIncludeFindingVisitor.PendingIncludes;

pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(skipCollectionNavigations: false);
pendingIncludeFindingVisitor.Visit(preProcessResult2.state.PendingSelector.Body);
var pendingIncludes2 = pendingIncludeFindingVisitor.PendingIncludes;

if (pendingIncludes1.Count != pendingIncludes2.Count)
{
throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands);
}

foreach (var (i1, i2) in pendingIncludes1.Zip(pendingIncludes2, (i1, i2) => (i1, i2)))
{
if (i1.SourceMapping.RootEntityType != i2.SourceMapping.RootEntityType
|| i1.NavTreeNode.Navigation != i2.NavTreeNode.Navigation)
{
throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands);
}
}

// If the siblings are different types, one is derived from the other the set operation returns the less derived type.
// Find that.
var clrType1 = preProcessResult1.state.CurrentParameter.Type;
var clrType2 = preProcessResult2.state.CurrentParameter.Type;
var parentState = clrType1.IsAssignableFrom(clrType2) ? preProcessResult1.state : preProcessResult2.state;

var rewritten = methodCallExpression.Update(null, new[] { preProcessResult1.source, preProcessResult2.source });

return new NavigationExpansionExpression(rewritten, parentState, methodCallExpression.Type);
}

private (Expression source, NavigationExpansionExpressionState state) PreProcessTerminatingOperation(NavigationExpansionExpression source)
{
var applyOrderingsResult = ApplyPendingOrderings(source.Operand, source.State);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
result = NavigationExpansionHelpers.AddNavigationJoin(
result.source,
result.parameter,
pendingIncludeNode.Value,
pendingIncludeNode.Key,
pendingIncludeNode.SourceMapping,
pendingIncludeNode.NavTreeNode,
navigationExpansionExpression.State,
new List<INavigation>(),
include: true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@ namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors

public class PendingIncludeFindingVisitor : ExpressionVisitor
{
public virtual Dictionary<NavigationTreeNode, SourceMapping> PendingIncludes { get; } = new Dictionary<NavigationTreeNode, SourceMapping>();
private bool _skipCollectionNavigations;

public PendingIncludeFindingVisitor(bool skipCollectionNavigations = true)
{
_skipCollectionNavigations = skipCollectionNavigations;
}

public virtual List<(NavigationTreeNode NavTreeNode, SourceMapping SourceMapping)> PendingIncludes { get; } =
new List<(NavigationTreeNode, SourceMapping)>();

protected override Expression VisitMember(MemberExpression memberExpression)
{
Expand Down Expand Up @@ -81,14 +89,16 @@ protected override Expression VisitExtension(Expression extensionExpression)

private void FindPendingReferenceIncludes(NavigationTreeNode node, SourceMapping sourceMapping)
{
if (node.Navigation != null && node.Navigation.IsCollection())
if (_skipCollectionNavigations && node.Navigation != null && node.Navigation.IsCollection())
{
return;
}

if (node.Included == NavigationTreeNodeIncludeMode.ReferencePending && node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete)
if (node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete
&& (node.Included == NavigationTreeNodeIncludeMode.ReferencePending
|| !_skipCollectionNavigations && node.Included == NavigationTreeNodeIncludeMode.Collection))
{
PendingIncludes[node] = sourceMapping;
PendingIncludes.Add((node, sourceMapping));
}

foreach (var child in node.Children)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ public override void Union_non_entity(bool isAsync) {}
public override Task Union_with_anonymous_type_projection(bool isAsync) => Task.CompletedTask;
public override Task Select_Union_unrelated(bool isAsync) => Task.CompletedTask;
public override Task Select_Union_different_fields_in_anonymous_with_subquery(bool isAsync) => Task.CompletedTask;
public override Task Union_Include(bool isAsync) => Task.CompletedTask;
public override Task Include_Union(bool isAsync) => Task.CompletedTask;
public override Task Select_Except_reference_projection(bool isAsync) => Task.CompletedTask;
public override void Include_Union_only_on_one_side_throws() {}
public override void Include_Union_different_includes_throws() {}
public override Task SubSelect_Union(bool isAsync) => Task.CompletedTask;
public override Task Client_eval_Union_FirstOrDefault(bool isAsync) => Task.CompletedTask;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,26 @@ public virtual Task Select_Union_different_fields_in_anonymous_with_subquery(boo
.Where(x => x.Foo == "Berlin"),
entryCount: 1);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Union_Include(bool isAsync)
=> AssertQuery<Customer>(isAsync, cs => cs
.Where(c => c.City == "Berlin")
.Union(cs.Where(c => c.City == "London"))
.Include(c => c.Orders),
entryCount: 59);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Include_Union(bool isAsync)
=> AssertQuery<Customer>(isAsync, cs => cs
.Where(c => c.City == "Berlin")
.Include(c => c.Orders)
.Union(cs
.Where(c => c.City == "London")
.Include(c => c.Orders)),
entryCount: 59);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_Except_reference_projection(bool isAsync)
Expand All @@ -260,6 +280,45 @@ public virtual Task Select_Except_reference_projection(bool isAsync)
.Select(o => o.Customer)),
entryCount: 88);

[ConditionalFact]
public virtual void Include_Union_only_on_one_side_throws()
{
using (var ctx = CreateContext())
{
Assert.Throws<NotSupportedException>(() =>
ctx.Customers
.Where(c => c.City == "Berlin")
.Include(c => c.Orders)
.Union(ctx.Customers.Where(c => c.City == "London"))
.ToList());

Assert.Throws<NotSupportedException>(() =>
ctx.Customers
.Where(c => c.City == "Berlin")
.Union(ctx.Customers
.Where(c => c.City == "London")
.Include(c => c.Orders))
.ToList());
}
}

[ConditionalFact]
public virtual void Include_Union_different_includes_throws()
{
using (var ctx = CreateContext())
{
Assert.Throws<NotSupportedException>(() =>
ctx.Customers
.Where(c => c.City == "Berlin")
.Include(c => c.Orders)
.Union(ctx.Customers
.Where(c => c.City == "London")
.Include(c => c.Orders)
.ThenInclude(o => o.OrderDetails))
.ToList());
}
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task SubSelect_Union(bool isAsync)
Expand Down
Loading

0 comments on commit 686d56d

Please sign in to comment.