Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Correct nullability for set operations
Browse files Browse the repository at this point in the history
Fixes #18135
roji committed Oct 17, 2019
1 parent 59fbd88 commit 69cc698
Showing 7 changed files with 101 additions and 15 deletions.
54 changes: 39 additions & 15 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
@@ -508,7 +508,7 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
if (joinedMapping.Value1 is EntityProjectionExpression entityProjection1
&& joinedMapping.Value2 is EntityProjectionExpression entityProjection2)
{
handleEntityMapping(joinedMapping.Key, select1, entityProjection1, select2, entityProjection2);
HandleEntityMapping(joinedMapping.Key, select1, entityProjection1, select2, entityProjection2);
continue;
}

@@ -522,15 +522,24 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
throw new InvalidOperationException("Set operations over different store types are currently unsupported");
}

var alias = generateUniqueAlias(
var alias = GenerateUniqueAlias(
joinedMapping.Key.Last?.Name
?? (innerColumn1 as ColumnExpression)?.Name
?? "c");

var innerProjection = new ProjectionExpression(innerColumn1, alias);
select1._projection.Add(innerProjection);
select2._projection.Add(new ProjectionExpression(innerColumn2, alias));
_projectionMapping[joinedMapping.Key] = new ColumnExpression(innerProjection, setExpression);
var innerProjection1 = new ProjectionExpression(innerColumn1, alias);
var innerProjection2 = new ProjectionExpression(innerColumn2, alias);
select1._projection.Add(innerProjection1);
select2._projection.Add(innerProjection2);
var outerProjection = new ColumnExpression(innerProjection1, setExpression);

if (IsNullableProjection(innerProjection1)
|| IsNullableProjection(innerProjection2))
{
outerProjection = outerProjection.MakeNullable();
}

_projectionMapping[joinedMapping.Key] = outerProjection;
continue;
}

@@ -548,7 +557,7 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
_tables.Clear();
_tables.Add(setExpression);

void handleEntityMapping(
void HandleEntityMapping(
ProjectionMember projectionMember,
SelectExpression select1, EntityProjectionExpression projection1,
SelectExpression select2, EntityProjectionExpression projection2)
@@ -562,23 +571,30 @@ void handleEntityMapping(
var propertyExpressions = new Dictionary<IProperty, ColumnExpression>();
foreach (var property in GetAllPropertiesInHierarchy(projection1.EntityType))
{
propertyExpressions[property] = addSetOperationColumnProjections(
propertyExpressions[property] = AddSetOperationColumnProjections(
select1, projection1.BindProperty(property),
select2, projection2.BindProperty(property));
}

_projectionMapping[projectionMember] = new EntityProjectionExpression(projection1.EntityType, propertyExpressions);
}

ColumnExpression addSetOperationColumnProjections(
ColumnExpression AddSetOperationColumnProjections(
SelectExpression select1, ColumnExpression column1,
SelectExpression select2, ColumnExpression column2)
{
var alias = generateUniqueAlias(column1.Name);
var innerProjection = new ProjectionExpression(column1, alias);
select1._projection.Add(innerProjection);
select2._projection.Add(new ProjectionExpression(column2, alias));
var outerProjection = new ColumnExpression(innerProjection, setExpression);
var alias = GenerateUniqueAlias(column1.Name);
var innerProjection1 = new ProjectionExpression(column1, alias);
var innerProjection2 = new ProjectionExpression(column2, alias);
select1._projection.Add(innerProjection1);
select2._projection.Add(innerProjection2);
var outerProjection = new ColumnExpression(innerProjection1, setExpression);
if (IsNullableProjection(innerProjection1)
|| IsNullableProjection(innerProjection2))
{
outerProjection = outerProjection.MakeNullable();
}

if (select1._identifier.Contains(column1))
{
_identifier.Add(outerProjection);
@@ -587,7 +603,7 @@ ColumnExpression addSetOperationColumnProjections(
return outerProjection;
}

string generateUniqueAlias(string baseAlias)
string GenerateUniqueAlias(string baseAlias)
{
var currentAlias = baseAlias ?? "";
var counter = 0;
@@ -598,6 +614,14 @@ string generateUniqueAlias(string baseAlias)

return currentAlias;
}

static bool IsNullableProjection(ProjectionExpression projectionExpression)
=> projectionExpression.Expression switch
{
ColumnExpression columnExpression => columnExpression.IsNullable,
SqlConstantExpression sqlConstantExpression => sqlConstantExpression.Value == null,
_ => true,
};
}

private ColumnExpression GenerateOuterColumn(SqlExpression projection, string alias = null)
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@ public override void Include_Union_different_includes_throws() { }
public override Task SubSelect_Union(bool isAsync) => Task.CompletedTask;
public override Task Client_eval_Union_FirstOrDefault(bool isAsync) => Task.CompletedTask;
public override Task GroupBy_Select_Union(bool isAsync) => Task.CompletedTask;
public override Task Union_over_columns_with_different_nullability(bool isAsync) => Task.CompletedTask;
public override Task Union_over_different_projection_types(bool isAsync, string leftType, string rightType) => Task.CompletedTask;
}
}
Original file line number Diff line number Diff line change
@@ -5709,5 +5709,21 @@ public virtual Task Null_check_different_structure_does_not_remove_null_checks(b
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3.Name) == "L4 01"));
}

[ConditionalFact]
public virtual void Union_over_entities_with_different_nullability()
{
using var ctx = CreateContext();

var query = ctx.Set<Level1>()
.GroupJoin(ctx.Set<Level2>(), l1 => l1.Id, l2 => l2.Level1_Optional_Id, (l1, l2s) => new { l1, l2s })
.SelectMany(g => g.l2s.DefaultIfEmpty(), (g, l2) => new { g.l1, l2 })
.Concat(ctx.Set<Level2>().GroupJoin(ctx.Set<Level1>(), l2 => l2.Level1_Optional_Id, l1 => l1.Id, (l2, l1s) => new { l2, l1s })
.SelectMany(g => g.l1s.DefaultIfEmpty(), (g, l1) => new { l1, g.l2 })
.Where(e => e.l1.Equals(null)))
.Select(e => e.l1.Id);

var result = query.ToList();
}
}
}
Original file line number Diff line number Diff line change
@@ -169,6 +169,10 @@ public override void Member_pushdown_with_collection_navigation_in_the_middle()
{
}

public override void Union_over_entities_with_different_nullability()
{
}

[ConditionalTheory(Skip = "Issue#16752")]
public override Task Include_inside_subquery(bool isAsync)
{
Original file line number Diff line number Diff line change
@@ -444,6 +444,17 @@ public virtual Task GroupBy_Select_Union(bool isAsync)
.Select(g => new { CustomerID = g.Key, Count = g.Count() })));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Union_over_columns_with_different_nullability(bool isAsync)
{
return AssertQuery(
isAsync, ss => ss.Set<Customer>()
.Select(c => "NonNullableConstant")
.Concat(ss.Set<Customer>()
.Select(c => (string)null)));
}

[ConditionalTheory]
#pragma warning disable xUnit1016 // MemberData must reference a public member
[MemberData(nameof(GetSetOperandTestCases))]
Original file line number Diff line number Diff line change
@@ -4357,6 +4357,24 @@ ELSE [l2].[Name]
END IS NOT NULL");
}

public override void Union_over_entities_with_different_nullability()
{
base.Union_over_entities_with_different_nullability();

AssertSql(
@"SELECT [t].[Id]
FROM (
SELECT [l].[Id], [l].[Date], [l].[Name], [l].[OneToMany_Optional_Self_Inverse1Id], [l].[OneToMany_Required_Self_Inverse1Id], [l].[OneToOne_Optional_Self1Id], [l0].[Id] AS [Id0], [l0].[Date] AS [Date0], [l0].[Level1_Optional_Id], [l0].[Level1_Required_Id], [l0].[Name] AS [Name0], [l0].[OneToMany_Optional_Inverse2Id], [l0].[OneToMany_Optional_Self_Inverse2Id], [l0].[OneToMany_Required_Inverse2Id], [l0].[OneToMany_Required_Self_Inverse2Id], [l0].[OneToOne_Optional_PK_Inverse2Id], [l0].[OneToOne_Optional_Self2Id]
FROM [LevelOne] AS [l]
LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Optional_Id]
UNION ALL
SELECT [l2].[Id], [l2].[Date], [l2].[Name], [l2].[OneToMany_Optional_Self_Inverse1Id], [l2].[OneToMany_Required_Self_Inverse1Id], [l2].[OneToOne_Optional_Self1Id], [l1].[Id] AS [Id0], [l1].[Date] AS [Date0], [l1].[Level1_Optional_Id], [l1].[Level1_Required_Id], [l1].[Name] AS [Name0], [l1].[OneToMany_Optional_Inverse2Id], [l1].[OneToMany_Optional_Self_Inverse2Id], [l1].[OneToMany_Required_Inverse2Id], [l1].[OneToMany_Required_Self_Inverse2Id], [l1].[OneToOne_Optional_PK_Inverse2Id], [l1].[OneToOne_Optional_Self2Id]
FROM [LevelTwo] AS [l1]
LEFT JOIN [LevelOne] AS [l2] ON [l1].[Level1_Optional_Id] = [l2].[Id]
WHERE [l2].[Id] IS NULL
) AS [t]");
}

private void AssertSql(params string[] expected)
{
Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
Original file line number Diff line number Diff line change
@@ -380,6 +380,18 @@ FROM [Customers] AS [c0]
GROUP BY [c0].[CustomerID]");
}

public override async Task Union_over_columns_with_different_nullability(bool isAsync)
{
await base.Union_over_columns_with_different_nullability(isAsync);

AssertSql(
@"SELECT N'NonNullableConstant' AS [c]
FROM [Customers] AS [c]
UNION ALL
SELECT NULL AS [c]
FROM [Customers] AS [c0]");
}

public override async Task Union_over_different_projection_types(bool isAsync, string leftType, string rightType)
{
await base.Union_over_different_projection_types(isAsync, leftType, rightType);

0 comments on commit 69cc698

Please sign in to comment.