Skip to content

Commit

Permalink
Additional null semantics improvements - not treating column as nulla…
Browse files Browse the repository at this point in the history
…ble if the null check x.prop != null is present before

Currently we blindly apply full null semantics expansion to comparisons with nullable columns, however is some cases columns that seem nullable can never be nullable at the time comparison is made because nulls have been filtered out, e.g. c.Name != null && c.Name == c.SomeOtherName
Fix is to peek into left side of binary expression if its AndAlso node and find all null checks for column expressions. When visiting the right side of the expression, all those null-checked columns can be treated as non-nullable.
  • Loading branch information
maumar committed Sep 12, 2019
1 parent 6407e8b commit 1730d69
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public class NullSemanticsRewritingExpressionVisitor : ExpressionVisitor
private readonly ISqlExpressionFactory _sqlExpressionFactory;

private bool _isNullable;
private readonly List<ColumnExpression> _nonNullableColumns = new List<ColumnExpression>();

public NullSemanticsRewritingExpressionVisitor(ISqlExpressionFactory sqlExpressionFactory)
{
Expand Down Expand Up @@ -73,7 +74,7 @@ private SqlConstantExpression VisitSqlConstantExpression(SqlConstantExpression s

private ColumnExpression VisitColumnExpression(ColumnExpression columnExpression)
{
_isNullable = columnExpression.IsNullable;
_isNullable = !_nonNullableColumns.Contains(columnExpression) && columnExpression.IsNullable;

return columnExpression;
}
Expand Down Expand Up @@ -191,13 +192,30 @@ private SqlFunctionExpression VisitSqlFunctionExpression(SqlFunctionExpression s
private SqlBinaryExpression VisitSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression)
{
_isNullable = false;

var nonNullableColumns = new List<ColumnExpression>();
if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso)
{
nonNullableColumns = FindNonNullableColumns(sqlBinaryExpression.Left);
}

var newLeft = (SqlExpression)Visit(sqlBinaryExpression.Left);
var leftNullable = _isNullable;

_isNullable = false;
if (nonNullableColumns.Count > 0)
{
_nonNullableColumns.AddRange(nonNullableColumns);
}

var newRight = (SqlExpression)Visit(sqlBinaryExpression.Right);
var rightNullable = _isNullable;

foreach (var nonNullableColumn in nonNullableColumns)
{
_nonNullableColumns.Remove(nonNullableColumn);
}

if (sqlBinaryExpression.OperatorType == ExpressionType.Coalesce)
{
_isNullable = leftNullable && rightNullable;
Expand Down Expand Up @@ -317,6 +335,40 @@ private SqlBinaryExpression VisitSqlBinaryExpression(SqlBinaryExpression sqlBina
return sqlBinaryExpression.Update(newLeft, newRight);
}

private List<ColumnExpression> FindNonNullableColumns(SqlExpression sqlExpression)
{
var result = new List<ColumnExpression>();
if (sqlExpression is SqlBinaryExpression sqlBinaryExpression)
{
if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)
{
if (sqlBinaryExpression.Left is ColumnExpression leftColumn
&& leftColumn.IsNullable
&& sqlBinaryExpression.Right is SqlConstantExpression rightConstant
&& rightConstant.Value == null)
{
result.Add(leftColumn);
}

if (sqlBinaryExpression.Right is ColumnExpression rightColumn
&& rightColumn.IsNullable
&& sqlBinaryExpression.Left is SqlConstantExpression leftConstant
&& leftConstant.Value == null)
{
result.Add(rightColumn);
}
}

if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso)
{
result.AddRange(FindNonNullableColumns(sqlBinaryExpression.Left));
result.AddRange(FindNonNullableColumns(sqlBinaryExpression.Right));
}
}

return result;
}

// ?a == ?b -> [(a == b) && (a != null && b != null)] || (a == null && b == null))
//
// a | b | F1 = a == b | F2 = (a != null && b != null) | F3 = F1 && F2 |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,44 @@ public virtual void Null_semantics_contains()
}
}

[ConditionalFact]
public virtual void Null_semantics_with_null_check_simple()
{
using (var ctx = CreateContext())
{
var query1 = ctx.Entities1.Where(e => e.NullableIntA != null && e.NullableIntA == e.NullableIntB);
var result1 = query1.ToList();

var query2 = ctx.Entities1.Where(e => e.NullableIntA != null && e.NullableIntA != e.NullableIntB);
var result2 = query2.ToList();

var query3 = ctx.Entities1.Where(e => e.NullableIntA != null && e.NullableIntA == e.IntC);
var result3 = query3.ToList();

var query4 = ctx.Entities1.Where(e => e.NullableIntA != null && e.NullableIntB != null && e.NullableIntA == e.NullableIntB);
var result4 = query4.ToList();

var query5 = ctx.Entities1.Where(e => e.NullableIntA != null && e.NullableIntB != null && e.NullableIntA != e.NullableIntB);
var result5 = query5.ToList();
}
}

[ConditionalFact]
public virtual void Null_semantics_with_null_check_complex()
{
using (var ctx = CreateContext())
{
var query1 = ctx.Entities1.Where(e => e.NullableIntA != null && ((e.NullableIntC != e.NullableIntA) || (e.NullableIntB != null && e.NullableIntA != e.NullableIntB)));
var result1 = query1.ToList();

var query2 = ctx.Entities1.Where(e => e.NullableIntA != null && ((e.NullableIntC != e.NullableIntA) || (e.NullableIntA != e.NullableIntB)));
var result2 = query2.ToList();

var query3 = ctx.Entities1.Where(e => (e.NullableIntA != null || e.NullableIntB != null) && e.NullableIntA == e.NullableIntC);
var result3 = query3.ToList();
}
}

protected static TResult Maybe<TResult>(object caller, Func<TResult> expression)
where TResult : class
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public override void String_ends_with_equals_nullable_column()
FROM [FunkyCustomers] AS [f]
CROSS JOIN [FunkyCustomers] AS [f0]
WHERE (CASE
WHEN (([f0].[LastName] = N'') AND [f0].[LastName] IS NOT NULL) OR ([f].[FirstName] IS NOT NULL AND ([f0].[LastName] IS NOT NULL AND (((RIGHT([f].[FirstName], LEN([f0].[LastName])) = [f0].[LastName]) AND (RIGHT([f].[FirstName], LEN([f0].[LastName])) IS NOT NULL AND [f0].[LastName] IS NOT NULL)) OR (RIGHT([f].[FirstName], LEN([f0].[LastName])) IS NULL AND [f0].[LastName] IS NULL)))) THEN CAST(1 AS bit)
WHEN (([f0].[LastName] = N'') AND [f0].[LastName] IS NOT NULL) OR ([f].[FirstName] IS NOT NULL AND ([f0].[LastName] IS NOT NULL AND (RIGHT([f].[FirstName], LEN([f0].[LastName])) = [f0].[LastName]))) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END = [f].[NullableBool]) AND [f].[NullableBool] IS NOT NULL");
}
Expand All @@ -39,7 +39,7 @@ public override void String_ends_with_not_equals_nullable_column()
FROM [FunkyCustomers] AS [f]
CROSS JOIN [FunkyCustomers] AS [f0]
WHERE (CASE
WHEN (([f0].[LastName] = N'') AND [f0].[LastName] IS NOT NULL) OR ([f].[FirstName] IS NOT NULL AND ([f0].[LastName] IS NOT NULL AND (((RIGHT([f].[FirstName], LEN([f0].[LastName])) = [f0].[LastName]) AND (RIGHT([f].[FirstName], LEN([f0].[LastName])) IS NOT NULL AND [f0].[LastName] IS NOT NULL)) OR (RIGHT([f].[FirstName], LEN([f0].[LastName])) IS NULL AND [f0].[LastName] IS NULL)))) THEN CAST(1 AS bit)
WHEN (([f0].[LastName] = N'') AND [f0].[LastName] IS NOT NULL) OR ([f].[FirstName] IS NOT NULL AND ([f0].[LastName] IS NOT NULL AND (RIGHT([f].[FirstName], LEN([f0].[LastName])) = [f0].[LastName]))) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END <> [f].[NullableBool]) OR [f].[NullableBool] IS NULL");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3669,7 +3669,7 @@ WHERE [g].[Discriminator] IN (N'Gear', N'Officer')
WHERE (([t0].[Discriminator] = N'Officer') AND [t0].[Discriminator] IS NOT NULL) AND ((
SELECT COUNT(*)
FROM [Gears] AS [g0]
WHERE ([g0].[Discriminator] IN (N'Gear', N'Officer') AND ([t0].[Nickname] IS NOT NULL AND (((([t0].[Nickname] = [g0].[LeaderNickname]) AND ([t0].[Nickname] IS NOT NULL AND [g0].[LeaderNickname] IS NOT NULL)) OR ([t0].[Nickname] IS NULL AND [g0].[LeaderNickname] IS NULL)) AND (([t0].[SquadId] = [g0].[LeaderSquadId]) AND [t0].[SquadId] IS NOT NULL)))) AND ([g0].[Nickname] = N'Dom')) > 0)");
WHERE ([g0].[Discriminator] IN (N'Gear', N'Officer') AND ([t0].[Nickname] IS NOT NULL AND ((([t0].[Nickname] = [g0].[LeaderNickname]) AND [g0].[LeaderNickname] IS NOT NULL) AND (([t0].[SquadId] = [g0].[LeaderSquadId]) AND [t0].[SquadId] IS NOT NULL)))) AND ([g0].[Nickname] = N'Dom')) > 0)");
}

public override void Select_null_conditional_with_inheritance()
Expand Down Expand Up @@ -4523,7 +4523,7 @@ LEFT JOIN (
SELECT [w0].[Id], [w0].[AmmunitionType], [w0].[IsAutomatic], [w0].[Name], [w0].[OwnerFullName], [w0].[SynergyWithId], (
SELECT COUNT(*)
FROM [Weapons] AS [w]
WHERE [t1].[FullName] IS NOT NULL AND ((([t1].[FullName] = [w].[OwnerFullName]) AND ([t1].[FullName] IS NOT NULL AND [w].[OwnerFullName] IS NOT NULL)) OR ([t1].[FullName] IS NULL AND [w].[OwnerFullName] IS NULL))) AS [c]
WHERE [t1].[FullName] IS NOT NULL AND (([t1].[FullName] = [w].[OwnerFullName]) AND [w].[OwnerFullName] IS NOT NULL)) AS [c]
FROM [Weapons] AS [w0]
LEFT JOIN (
SELECT [g1].[Nickname], [g1].[SquadId], [g1].[AssignedCityName], [g1].[CityOrBirthName], [g1].[Discriminator], [g1].[FullName], [g1].[HasSoulPatch], [g1].[LeaderNickname], [g1].[LeaderSquadId], [g1].[Rank]
Expand Down Expand Up @@ -6966,7 +6966,7 @@ WHERE [g].[Discriminator] IN (N'Gear', N'Officer')
OUTER APPLY (
SELECT TOP(50) [g0].[Nickname], [g0].[SquadId], [g0].[AssignedCityName], [g0].[CityOrBirthName], [g0].[Discriminator], [g0].[FullName], [g0].[HasSoulPatch], [g0].[LeaderNickname], [g0].[LeaderSquadId], [g0].[Rank]
FROM [Gears] AS [g0]
WHERE [g0].[Discriminator] IN (N'Gear', N'Officer') AND ([t0].[Nickname] IS NOT NULL AND (((([t0].[Nickname] = [g0].[LeaderNickname]) AND ([t0].[Nickname] IS NOT NULL AND [g0].[LeaderNickname] IS NOT NULL)) OR ([t0].[Nickname] IS NULL AND [g0].[LeaderNickname] IS NULL)) AND (([t0].[SquadId] = [g0].[LeaderSquadId]) AND [t0].[SquadId] IS NOT NULL)))
WHERE [g0].[Discriminator] IN (N'Gear', N'Officer') AND ([t0].[Nickname] IS NOT NULL AND ((([t0].[Nickname] = [g0].[LeaderNickname]) AND [g0].[LeaderNickname] IS NOT NULL) AND (([t0].[SquadId] = [g0].[LeaderSquadId]) AND [t0].[SquadId] IS NOT NULL)))
) AS [t1]
WHERE ([t0].[Discriminator] = N'Officer') AND [t0].[Discriminator] IS NOT NULL
ORDER BY [t].[Id], [t1].[Nickname], [t1].[SquadId]");
Expand Down Expand Up @@ -7004,7 +7004,7 @@ LEFT JOIN (
SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOrBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank]
FROM [Gears] AS [g]
WHERE [g].[Discriminator] IN (N'Gear', N'Officer')
) AS [t0] ON (((([t].[GearNickName] = [t0].[Nickname]) AND [t].[GearNickName] IS NOT NULL) AND (([t].[GearSquadId] = [t0].[SquadId]) AND [t].[GearSquadId] IS NOT NULL)) AND [t].[Note] IS NOT NULL) AND [t].[Note] IS NOT NULL
) AS [t0] ON ((([t].[GearNickName] = [t0].[Nickname]) AND [t].[GearNickName] IS NOT NULL) AND (([t].[GearSquadId] = [t0].[SquadId]) AND [t].[GearSquadId] IS NOT NULL)) AND [t].[Note] IS NOT NULL
ORDER BY [t].[Id], [t0].[Nickname], [t0].[SquadId]");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,50 @@ public override void Null_semantics_contains()
@"");
}

public override void Null_semantics_with_null_check_simple()
{
base.Null_semantics_with_null_check_simple();

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IS NOT NULL AND (([e].[NullableIntA] = [e].[NullableIntB]) AND [e].[NullableIntB] IS NOT NULL)",
//
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IS NOT NULL AND (([e].[NullableIntA] <> [e].[NullableIntB]) OR [e].[NullableIntB] IS NULL)",
//
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IS NOT NULL AND ([e].[NullableIntA] = [e].[IntC])",
//
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE ([e].[NullableIntA] IS NOT NULL AND [e].[NullableIntB] IS NOT NULL) AND ([e].[NullableIntA] = [e].[NullableIntB])",
//
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE ([e].[NullableIntA] IS NOT NULL AND [e].[NullableIntB] IS NOT NULL) AND ([e].[NullableIntA] <> [e].[NullableIntB])");
}

public override void Null_semantics_with_null_check_complex()
{
base.Null_semantics_with_null_check_complex();

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IS NOT NULL AND ((([e].[NullableIntC] <> [e].[NullableIntA]) OR [e].[NullableIntC] IS NULL) OR ([e].[NullableIntB] IS NOT NULL AND ([e].[NullableIntA] <> [e].[NullableIntB])))",
//
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IS NOT NULL AND ((([e].[NullableIntC] <> [e].[NullableIntA]) OR [e].[NullableIntC] IS NULL) OR (([e].[NullableIntA] <> [e].[NullableIntB]) OR [e].[NullableIntB] IS NULL))",
//
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE ([e].[NullableIntA] IS NOT NULL OR [e].[NullableIntB] IS NOT NULL) AND ((([e].[NullableIntA] = [e].[NullableIntC]) AND ([e].[NullableIntA] IS NOT NULL AND [e].[NullableIntC] IS NOT NULL)) OR ([e].[NullableIntA] IS NULL AND [e].[NullableIntC] IS NULL))");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ FROM [Orders] AS [o]
WHERE (
SELECT COUNT(*)
FROM [Orders] AS [o0]
WHERE ([c].[CustomerID] IS NOT NULL AND ((([c].[CustomerID] = [o0].[CustomerID]) AND ([c].[CustomerID] IS NOT NULL AND [o0].[CustomerID] IS NOT NULL)) OR ([c].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL))) AND ([o0].[OrderID] > 10260)) > 30");
WHERE ([c].[CustomerID] IS NOT NULL AND (([c].[CustomerID] = [o0].[CustomerID]) AND [o0].[CustomerID] IS NOT NULL)) AND ([o0].[OrderID] > 10260)) > 30");
}

public override async Task Client_groupjoin_with_orderby_key_descending(bool isAsync)
Expand Down
Loading

0 comments on commit 1730d69

Please sign in to comment.