Skip to content

Commit

Permalink
Fix comparison of nullable values (#33757)
Browse files Browse the repository at this point in the history
Fix comparison of nullable values

In C# an ordered comparison (<, >, <=, >=) between two nullable values always
returns a boolean value: if either operand is null, the result is false;
otherwise, the result is that of the comparison of the (non-null) values.

Fixes #33752
  • Loading branch information
ranma42 authored May 31, 2024
1 parent 3fa01db commit b77d2f4
Show file tree
Hide file tree
Showing 16 changed files with 396 additions and 38 deletions.
22 changes: 22 additions & 0 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,28 @@ protected virtual SqlExpression VisitSqlBinary(
nullable = leftNullable || rightNullable;
var result = sqlBinaryExpression.Update(left, right);

if (nullable && !optimize && result.OperatorType
is ExpressionType.GreaterThan
or ExpressionType.GreaterThanOrEqual
or ExpressionType.LessThan
or ExpressionType.LessThanOrEqual)
{
// https://learn.microsoft.com/en-us/dotnet/csharp/language-reference/builtin-types/nullable-value-types#lifted-operators
// For the comparison operators <, >, <=, and >=, if one or both
// operands are null, the result is false; otherwise, the contained
// values of operands are compared.

// if either operand is NULL, the SQL comparison would return NULL;
// to match the C# semantics, replace expr with
// CASE WHEN expr THEN TRUE ELSE FALSE

nullable = false;
return _sqlExpressionFactory.Case(
[new(result, _sqlExpressionFactory.Constant(true, result.TypeMapping))],
_sqlExpressionFactory.Constant(false, result.TypeMapping)
);
}

return result is SqlBinaryExpression sqlBinaryResult
&& sqlBinaryExpression.OperatorType is ExpressionType.AndAlso or ExpressionType.OrElse
? SimplifyLogicalSqlBinaryExpression(sqlBinaryResult)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,15 @@ public SqliteStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
instance = _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping);
pattern = _sqlExpressionFactory.ApplyTypeMapping(pattern, stringTypeMapping);

// Note: we add IS NOT NULL checks here since we don't do null semantics/compensation for comparison (greater-than)
return
_sqlExpressionFactory.AndAlso(
_sqlExpressionFactory.IsNotNull(instance),
_sqlExpressionFactory.AndAlso(
_sqlExpressionFactory.IsNotNull(pattern),
_sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"instr",
new[] { instance, pattern },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(int)),
_sqlExpressionFactory.Constant(0))));
_sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"instr",
new[] { instance, pattern },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(int)),
_sqlExpressionFactory.Constant(0));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,44 @@ FROM root c
""");
});

public override Task String_Contains_negated_in_predicate(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.String_Contains_negated_in_predicate(a);
AssertSql(
"""
SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND NOT(CONTAINS(c["CompanyName"], c["ContactName"])))
""");
});

public override Task String_Contains_negated_in_projection(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.String_Contains_negated_in_projection(a);
AssertSql(
"""
SELECT VALUE {"Id" : c["CustomerID"], "Value" : NOT(CONTAINS(c["CompanyName"], c["ContactName"]))}
FROM root c
WHERE (c["Discriminator"] = "Customer")
""");
});

[ConditionalTheory(Skip = "issue #33858")]
public override Task String_Contains_in_projection(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.String_Contains_in_projection(a);
AssertSql("");
});

public override Task String_Join_over_non_nullable_column(bool async)
=> AssertTranslationFailed(() => base.String_Join_over_non_nullable_column(async));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,9 @@ await AssertQueryScalar(
e => (e.BoolA ? e.NullableBoolA != e.NullableBoolB : e.BoolC) != e.BoolB
? e.BoolA
: e.NullableBoolB == e.NullableBoolC).Select(e => e.Id));
await AssertQueryScalar(
async,
ss => ss.Set<NullSemanticsEntity1>().Select(e => (e.BoolA ? e.NullableIntA : e.IntB) > e.IntC));
}

[ConditionalTheory]
Expand Down Expand Up @@ -1558,7 +1561,7 @@ public virtual async Task Negated_order_comparison_on_non_nullable_arguments_get
await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>().Where(e => !(e.IntA <= i)).Select(e => e.Id));
}

[ConditionalTheory(Skip = "issue #9544")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Negated_order_comparison_on_nullable_arguments_doesnt_get_optimized(bool async)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,29 @@ public virtual Task String_Contains_Column(bool async)
async,
ss => ss.Set<Customer>().Where(c => c.ContactName.Contains(c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Contains_in_projection(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().Select(c => new { Id = c.CustomerID, Value = c.CompanyName.Contains(c.ContactName) }),
elementSorter: e => e.Id);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Contains_negated_in_predicate(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => !c.CompanyName.Contains(c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Contains_negated_in_projection(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().Select(c => new { Id = c.CustomerID, Value = !c.CompanyName.Contains(c.ContactName) }),
elementSorter: e => e.Id);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_FirstOrDefault_MethodCall(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4126,10 +4126,10 @@ INNER JOIN (
FROM [LevelTwo] AS [l0]
) AS [l1]
GROUP BY [l1].[Key]
) AS [l2] ON [l].[Id] = [l2].[Key] AND CAST(0 AS bit) = CASE
) AS [l2] ON [l].[Id] = [l2].[Key] AND CASE
WHEN [l2].[Sum] <= 10 THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END = CAST(0 AS bit)
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -974,10 +974,10 @@ WHEN [l2].[OneToOne_Required_PK_Date] IS NOT NULL AND [l2].[Level1_Required_Id]
WHERE [l2].[OneToOne_Required_PK_Date] IS NOT NULL AND [l2].[Level1_Required_Id] IS NOT NULL AND [l2].[OneToMany_Required_Inverse2Id] IS NOT NULL
) AS [s]
GROUP BY [s].[Key]
) AS [s1] ON [l].[Id] = [s1].[Key] AND CAST(0 AS bit) = CASE
) AS [s1] ON [l].[Id] = [s1].[Key] AND CASE
WHEN [s1].[Sum] <= 10 THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END = CAST(0 AS bit)
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2750,6 +2750,46 @@ ELSE CHARINDEX(CONVERT(varchar(11), [o].[OrderID]), '123') - 1
""");
}

public override async Task String_Contains_in_projection(bool async)
{
await base.String_Contains_in_projection(async);

AssertSql(
"""
SELECT [c].[CustomerID] AS [Id], CASE
WHEN [c].[ContactName] IS NOT NULL AND (CHARINDEX([c].[ContactName], [c].[CompanyName]) > 0 OR [c].[ContactName] LIKE N'') THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [Value]
FROM [Customers] AS [c]
""");
}

public override async Task String_Contains_negated_in_predicate(bool async)
{
await base.String_Contains_negated_in_predicate(async);

AssertSql(
"""
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[ContactName] IS NULL OR (CHARINDEX([c].[ContactName], [c].[CompanyName]) <= 0 AND [c].[ContactName] NOT LIKE N'')
""");
}

public override async Task String_Contains_negated_in_projection(bool async)
{
await base.String_Contains_negated_in_projection(async);

AssertSql(
"""
SELECT [c].[CustomerID] AS [Id], CASE
WHEN [c].[ContactName] IS NULL OR (CHARINDEX([c].[ContactName], [c].[CompanyName]) <= 0 AND [c].[ContactName] NOT LIKE N'') THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [Value]
FROM [Customers] AS [c]
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task StandardDeviation(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1757,6 +1757,17 @@ ELSE CASE
ELSE CAST(0 AS bit)
END
END = CAST(1 AS bit)
""",
//
"""
SELECT CASE
WHEN CASE
WHEN [e].[BoolA] = CAST(1 AS bit) THEN [e].[NullableIntA]
ELSE [e].[IntB]
END > [e].[IntC] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
FROM [Entities1] AS [e]
""");
}

Expand Down Expand Up @@ -2518,7 +2529,49 @@ public override async Task Negated_order_comparison_on_nullable_arguments_doesnt
await base.Negated_order_comparison_on_nullable_arguments_doesnt_get_optimized(async);

AssertSql(
@"");
"""
@__i_0='1' (Nullable = true)

SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[NullableIntA] > @__i_0 THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END = CAST(0 AS bit)
""",
//
"""
@__i_0='1' (Nullable = true)

SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[NullableIntA] >= @__i_0 THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END = CAST(0 AS bit)
""",
//
"""
@__i_0='1' (Nullable = true)

SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[NullableIntA] < @__i_0 THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END = CAST(0 AS bit)
""",
//
"""
@__i_0='1' (Nullable = true)

SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[NullableIntA] <= @__i_0 THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END = CAST(0 AS bit)
""");
}

public override async Task Nullable_column_info_propagates_inside_binary_AndAlso(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public override async Task FromSqlRaw_queryable_composed(bool async)
FROM (
SELECT * FROM "Customers"
) AS "m"
WHERE "m"."ContactName" IS NOT NULL AND instr("m"."ContactName", 'z') > 0
WHERE instr("m"."ContactName", 'z') > 0
""");
}

Expand Down Expand Up @@ -75,7 +75,7 @@ public override async Task FromSqlRaw_composed_with_common_table_expression(bool
)
SELECT * FROM "Customers2"
) AS "m"
WHERE "m"."ContactName" IS NOT NULL AND instr("m"."ContactName", 'z') > 0
WHERE instr("m"."ContactName", 'z') > 0
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public override async Task String_Contains_and_StartsWith_with_same_parameter(bo

SELECT "f"."Id", "f"."FirstName", "f"."LastName", "f"."NullableBool"
FROM "FunkyCustomers" AS "f"
WHERE ("f"."FirstName" IS NOT NULL AND instr("f"."FirstName", @__s_0) > 0) OR "f"."LastName" LIKE @__s_0_startswith ESCAPE '\'
WHERE instr("f"."FirstName", @__s_0) > 0 OR "f"."LastName" LIKE @__s_0_startswith ESCAPE '\'
""");
}

Expand Down
Loading

0 comments on commit b77d2f4

Please sign in to comment.