Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Nov 27, 2019
1 parent 54620ca commit c1513db
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 48 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 1 addition & 4 deletions src/EFCore.Relational/Properties/RelationalStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,7 @@
<data name="PendingAmbientTransaction" xml:space="preserve">
<value>This connection was used with an ambient transaction. The original ambient transaction needs to be completed before this connection can be used outside of it.</value>
</data>
<data name="SetOperationNotWithinEntityTypeHierarchy" xml:space="preserve">
<value>Set operations (Union, Concat, Intersect, Except) are only supported over entity types within the same type hierarchy.</value>
</data>
<data name="FromSqlNonComposable" xml:space="preserve">
<value>FromSqlRaw or FromSqlInterpolated was called with non-composable SQL and with a query composing over it. Consider calling `AsEnumerable` after the FromSqlRaw or FromSqlInterpolated method to perform the composition on the client side.</value>
</data>
</root>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -1185,19 +1185,37 @@ private ShapedQueryExpression AggregateResultShaper(
}

/// <summary>
/// If a set operation is between different entity types, the query will return their closest common ancestor.
/// Modify the shaper accordingly.
/// If a set operation is between different entity types, the query will return their closest common ancestor.
/// Modify the shaper accordingly.
/// </summary>
private void ModifyShaperForSetOperation(ShapedQueryExpression source1, ShapedQueryExpression source2)
{
if (RemoveConvert(source1.ShaperExpression) is EntityShaperExpression shaper1
&& RemoveConvert(source2.ShaperExpression) is EntityShaperExpression shaper2
&& shaper1.EntityType != shaper2.EntityType)
{
var closestCommonParent = shaper1.EntityType.GetClosestCommonParent(shaper2.EntityType);

source1.ShaperExpression = new EntityShaperExpression(
shaper1.EntityType.GetClosestCommonParent(shaper2.EntityType),
closestCommonParent,
shaper1.ValueBufferExpression,
shaper1.IsNullable);

// If there's a convert node on either side (set operation over different entity type) and it's
// converting to a higher type in the hierarchy, add back a convert node to that type.
var convertType =
source1.ShaperExpression is UnaryExpression unary1
&& unary1.NodeType == ExpressionType.Convert
? unary1.Type
: source2.ShaperExpression is UnaryExpression unary2
&& unary2.NodeType == ExpressionType.Convert
? unary2.Type
: null;

if (convertType != null && convertType != closestCommonParent.ClrType)
{
source1.ShaperExpression = Expression.Convert(source1.ShaperExpression, convertType);
}
}

static Expression RemoveConvert(Expression expression)
Expand Down
67 changes: 49 additions & 18 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
if (joinedMapping.Value1 is EntityProjectionExpression entityProjection1
&& joinedMapping.Value2 is EntityProjectionExpression entityProjection2)
{
HandleEntityMapping(joinedMapping.Key, entityProjection1, entityProjection2);
_projectionMapping[joinedMapping.Key] = LiftEntityProjectionFromSetOperands(entityProjection1, entityProjection2);
continue;
}

Expand Down Expand Up @@ -575,7 +575,7 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
_tables.Clear();
_tables.Add(setExpression);

void HandleEntityMapping(ProjectionMember projectionMember, EntityProjectionExpression projection1, EntityProjectionExpression projection2)
EntityProjectionExpression LiftEntityProjectionFromSetOperands(EntityProjectionExpression projection1, EntityProjectionExpression projection2)
{
var (entityType1, entityType2) = (projection1.EntityType, projection2.EntityType);

Expand All @@ -586,14 +586,13 @@ void HandleEntityMapping(ProjectionMember projectionMember, EntityProjectionExpr
{
foreach (var property in GetAllPropertiesInHierarchy(entityType1))
{
propertyExpressions[property] = AddSetOperationColumnProjections(
propertyExpressions[property] = GenerateOuterSetOperationColumn(
property.GetColumnName(),
projection1.BindProperty(property),
projection2.BindProperty(property));
}

_projectionMapping[projectionMember] = new EntityProjectionExpression(entityType1, propertyExpressions);
return;
return new EntityProjectionExpression(entityType1, propertyExpressions);
}

// We're doing a set operation over two different entity types (within the same hierarchy).
Expand All @@ -606,16 +605,18 @@ void HandleEntityMapping(ProjectionMember projectionMember, EntityProjectionExpr

if (commonParentEntityType == null)
{
throw new InvalidOperationException(RelationalStrings.SetOperationNotWithinEntityTypeHierarchy);
throw new InvalidOperationException("No common parent in set operation over different types!");
}

var properties1 = GetAllPropertiesInHierarchy(entityType1).ToList();
var properties2 = GetAllPropertiesInHierarchy(entityType2).ToList();
var allProperties1 = GetAllPropertiesInHierarchy(entityType1).ToList();
var allProperties2 = GetAllPropertiesInHierarchy(entityType2).ToList();
var properties1 = allProperties1.ToList();
var properties2 = allProperties2.ToList();

// First handle shared properties that come from common base entity types
foreach (var property in properties1.Intersect(properties2).ToArray())
{
propertyExpressions[property] = AddSetOperationColumnProjections(
propertyExpressions[property] = GenerateOuterSetOperationColumn(
property.GetColumnName(),
projection1.BindProperty(property),
projection2.BindProperty(property));
Expand All @@ -634,7 +635,7 @@ void HandleEntityMapping(ProjectionMember projectionMember, EntityProjectionExpr
(p1, p2) => (p1, p2))
.ToArray())
{
var outerProjection = AddSetOperationColumnProjections(
var outerProjection = GenerateOuterSetOperationColumn(
group1.Key,
projection1.BindProperty(group1.First()),
projection2.BindProperty(group2.First()));
Expand All @@ -655,7 +656,7 @@ void HandleEntityMapping(ProjectionMember projectionMember, EntityProjectionExpr
// Remaining properties exist only on one side, so inject a null constant projection on the other side.
foreach (var property in properties1)
{
propertyExpressions[property] = AddSetOperationColumnProjections(
propertyExpressions[property] = GenerateOuterSetOperationColumn(
property.GetColumnName(),
projection1.BindProperty(property),
new SqlConstantExpression(
Expand All @@ -665,7 +666,7 @@ void HandleEntityMapping(ProjectionMember projectionMember, EntityProjectionExpr

foreach (var property in properties2)
{
propertyExpressions[property] = AddSetOperationColumnProjections(
propertyExpressions[property] = GenerateOuterSetOperationColumn(
property.GetColumnName(),
new SqlConstantExpression(
Constant(null, property.ClrType.MakeNullable()),
Expand All @@ -676,12 +677,12 @@ void HandleEntityMapping(ProjectionMember projectionMember, EntityProjectionExpr
// Finally, the shaper will expect to read properties from unrelated siblings, since the set operations
// return type is the common ancestor. Add appropriate null constant projections for both sides.
// See #16215 for a possible optimization.
var unrelatedSiblingProperties = GetAllPropertiesInHierarchy(commonParentEntityType)
.Except(GetAllPropertiesInHierarchy(entityType1))
.Except(GetAllPropertiesInHierarchy(entityType2));
var unrelatedSiblingProperties = GetAllPropertiesInHierarchy(commonParentEntityType).ToList()
.Except(allProperties1)
.Except(allProperties2);
foreach (var property in unrelatedSiblingProperties)
{
propertyExpressions[property] = AddSetOperationColumnProjections(
propertyExpressions[property] = GenerateOuterSetOperationColumn(
property.GetColumnName(),
new SqlConstantExpression(
Constant(null, property.ClrType.MakeNullable()),
Expand All @@ -691,9 +692,39 @@ void HandleEntityMapping(ProjectionMember projectionMember, EntityProjectionExpr
property.GetRelationalTypeMapping()));
}

_projectionMapping[projectionMember] = new EntityProjectionExpression(commonParentEntityType, propertyExpressions);
var newEntityProjection = new EntityProjectionExpression(commonParentEntityType, propertyExpressions);

ColumnExpression AddSetOperationColumnProjections(string columnName, SqlExpression innerExpression1, SqlExpression innerExpression2)
// Also lift nested entity projections
foreach (var navigation in projection1.EntityType.GetTypesInHierarchy()
.SelectMany(EntityTypeExtensions.GetDeclaredNavigations))
{
var boundEntityShaperExpression1 = projection1.BindNavigation(navigation);
var boundEntityShaperExpression2 = projection2.BindNavigation(navigation);

if (boundEntityShaperExpression1 == null
&& boundEntityShaperExpression2 == null)
{
continue;
}

if (boundEntityShaperExpression1 == null
&& boundEntityShaperExpression2 != null
|| boundEntityShaperExpression2 == null
&& boundEntityShaperExpression1 != null)
{
throw new InvalidOperationException(CoreStrings.SetOperationWithDifferentIncludesInOperands);
}

var newInnerEntityProjection = LiftEntityProjectionFromSetOperands(
(EntityProjectionExpression)boundEntityShaperExpression1.ValueBufferExpression,
(EntityProjectionExpression)boundEntityShaperExpression2.ValueBufferExpression);
boundEntityShaperExpression1 = boundEntityShaperExpression1.Update(newInnerEntityProjection);
newEntityProjection.AddNavigationBinding(navigation, boundEntityShaperExpression1);
}

return newEntityProjection;

ColumnExpression GenerateOuterSetOperationColumn(string columnName, SqlExpression innerExpression1, SqlExpression innerExpression2)
{
if (expressionsByColumnName.TryGetValue(columnName, out var outerProjection))
{
Expand Down
22 changes: 17 additions & 5 deletions test/EFCore.Specification.Tests/Query/InheritanceTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ protected virtual void UseTransaction(DatabaseFacade facade, IDbContextTransacti
{
}

#region Set operations

[ConditionalFact]
public virtual void Union_of_supertype_with_itself_with_properties_mapped_to_same_column()
{
Expand Down Expand Up @@ -483,47 +485,57 @@ public virtual void Concat_siblings_with_two_properties_mapped_to_same_column()
public virtual void OfType_Union_subquery()
{
using var context = CreateContext();
context.Set<Animal>()
var kiwis = context.Set<Animal>()
.OfType<Kiwi>()
.Union(
context.Set<Animal>()
.OfType<Kiwi>())
.Where(o => o.FoundOn == Island.North)
.Where(o => o.FoundOn == Island.South)
.ToList();

Assert.Equal("Great spotted kiwi", Assert.Single(kiwis).Name);
}

[ConditionalFact]
public virtual void OfType_Union_OfType()
{
using var context = CreateContext();
context.Set<Bird>()
var kiwis = context.Set<Bird>()
.OfType<Kiwi>()
.Union(context.Set<Bird>())
.OfType<Kiwi>()
.ToList();

Assert.Equal("Great spotted kiwi", Assert.Single(kiwis).Name);
}

[ConditionalFact]
public virtual void Subquery_OfType()
{
using var context = CreateContext();
context.Set<Bird>()
var kiwis = context.Set<Bird>()
.Take(5)
.Distinct() // Causes pushdown
.OfType<Kiwi>()
.ToList();

Assert.Equal("Great spotted kiwi", Assert.Single(kiwis).Name);
}

[ConditionalFact]
public virtual void Union_entity_equality()
{
using var context = CreateContext();
context.Set<Kiwi>()
var kiwis = context.Set<Kiwi>()
.Union(context.Set<Eagle>().Cast<Bird>())
.Where(b => b == null)
.ToList();

Assert.Empty(kiwis);
}

#endregion Set operations

[ConditionalFact]
public virtual void Setting_foreign_key_to_a_different_type_throws()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,11 @@ public virtual Task Select_Union_different_fields_in_anonymous_with_subquery(boo
.Union(
ss.Set<Customer>()
.Where(c => c.City == "London")
.Select(c => new { Foo = c.PostalCode, Customer = c })), // Foo is PostalCode
entryCount: 7);
.Select(c => new { Foo = c.PostalCode, Customer = c })) // Foo is PostalCode
.OrderBy(c => c.Foo)
.Skip(1)
.Take(10),
entryCount: 6);
}

[ConditionalTheory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ FROM [Animal] AS [a]
WHERE [a].[Discriminator] = N'Kiwi'");
}

#region Set operations

public override void Union_of_supertype_with_itself_with_properties_mapped_to_same_column()
{
base.Union_of_supertype_with_itself_with_properties_mapped_to_same_column();
Expand Down Expand Up @@ -483,15 +485,15 @@ WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND ([a].[Discriminator] = N'Ki
FROM [Animal] AS [a0]
WHERE [a0].[Discriminator] IN (N'Eagle', N'Kiwi') AND ([a0].[Discriminator] = N'Kiwi')
) AS [t]
WHERE [t].[FoundOn] = CAST(0 AS tinyint)");
WHERE [t].[FoundOn] = CAST(1 AS tinyint)");
}

public override void OfType_Union_OfType()
{
base.OfType_Union_OfType();

AssertSql(
@"SELECT [t].[Species], [t].[CountryId], [t].[Discriminator], [t].[Name], [t].[EagleId], [t].[IsFlightless], [t].[Group], [t].[FoundOn]
@"SELECT [t].[Species], [t].[CountryId], [t].[Discriminator], [t].[Name], [t].[EagleId], [t].[IsFlightless], [t].[FoundOn]
FROM (
SELECT [a].[Species], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[EagleId], [a].[IsFlightless], [a].[FoundOn], NULL AS [Group]
FROM [Animal] AS [a]
Expand All @@ -500,7 +502,8 @@ WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND ([a].[Discriminator] = N'Ki
SELECT [a0].[Species], [a0].[CountryId], [a0].[Discriminator], [a0].[Name], [a0].[EagleId], [a0].[IsFlightless], [a0].[FoundOn], [a0].[Group]
FROM [Animal] AS [a0]
WHERE [a0].[Discriminator] IN (N'Eagle', N'Kiwi')
) AS [t]");
) AS [t]
WHERE [t].[Discriminator] = N'Kiwi'");
}

public override void Subquery_OfType()
Expand Down Expand Up @@ -537,6 +540,8 @@ FROM [Animal] AS [a0]
WHERE CAST(0 AS bit) = CAST(1 AS bit)");
}

#endregion Set operations

protected override void UseTransaction(DatabaseFacade facade, IDbContextTransaction transaction)
=> facade.UseTransaction(transaction.GetDbTransaction());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,21 @@ public override async Task Select_Union_different_fields_in_anonymous_with_subqu
await base.Select_Union_different_fields_in_anonymous_with_subquery(async);

AssertSql(
@"SELECT [c].[City] AS [Foo], [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[City] = N'Berlin'
UNION
SELECT [c0].[PostalCode] AS [Foo], [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] = N'London'");
@"@__p_0='1'
@__p_1='10'
SELECT [t].[Foo], [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region]
FROM (
SELECT [c].[City] AS [Foo], [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[City] = N'Berlin'
UNION
SELECT [c0].[PostalCode] AS [Foo], [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] = N'London'
) AS [t]
ORDER BY [t].[Foo]
OFFSET @__p_0 ROWS FETCH NEXT @__p_1 ROWS ONLY");
}

public override async Task Union_Include(bool async)
Expand Down

0 comments on commit c1513db

Please sign in to comment.