Skip to content

Commit

Permalink
Use XOR to translate some == and != expressions (#34124)
Browse files Browse the repository at this point in the history
* Use XOR to translate some `==` and `!=` expressions

When the parent expression is not a predicate, translate `x != y` to:
```sql
x ^ y
```

instead of

```sql
CASE
    WHEN x <> y THEN CAST(1 AS bit)
    ELSE CAST(0 AS bit)
END
```

Similarly, translate `x == y` to:

```sql
x ^ y ^ CAST(1 AS bit)
```

instead of

```sql
CASE
    WHEN x == y THEN CAST(1 AS bit)
    ELSE CAST(0 AS bit)
END
```

Contributes to #34001 for simple cases (comparison of BIT expressions).
  • Loading branch information
ranma42 authored Jul 3, 2024
1 parent 4649fb3 commit 0337960
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,32 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpres

_isSearchCondition = parentIsSearchCondition;

if (!parentIsSearchCondition
&& newLeft.Type == typeof(bool) && newRight.Type == typeof(bool)
&& sqlBinaryExpression.OperatorType is ExpressionType.NotEqual or ExpressionType.Equal)
{
// on BIT, "lhs != rhs" is the same as "lhs ^ rhs", except that the
// first is a boolean, the second is a BIT
var result = _sqlExpressionFactory.MakeBinary(
ExpressionType.ExclusiveOr,
newLeft,
newRight,
sqlBinaryExpression.TypeMapping)!;

// "lhs == rhs" is the same as "NOT(lhs == rhs)" aka "lhs ^ rhs ^ 1"
if (sqlBinaryExpression.OperatorType is ExpressionType.Equal)
{
result = _sqlExpressionFactory.MakeBinary(
ExpressionType.ExclusiveOr,
result,
_sqlExpressionFactory.Constant(true, result.TypeMapping),
result.TypeMapping
)!;
}

return result;
}

sqlBinaryExpression = sqlBinaryExpression.Update(newLeft, newRight);
var condition = sqlBinaryExpression.OperatorType is ExpressionType.AndAlso
or ExpressionType.OrElse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1396,12 +1396,9 @@ public override async Task Where_bool_member_and_parameter_compared_to_binary_ex
SELECT [p].[ProductID], [p].[Discontinued], [p].[ProductName], [p].[SupplierID], [p].[UnitPrice], [p].[UnitsInStock]
FROM [Products] AS [p]
WHERE [p].[Discontinued] = CASE
WHEN CASE
WHEN [p].[ProductID] > 50 THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END <> @__prm_0 THEN CAST(1 AS bit)
WHEN [p].[ProductID] > 50 THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
END ^ @__prm_0
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,7 @@ public override async Task Rewrite_compare_bool_with_bool(bool async)

AssertSql(
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) AS [X]
FROM [Entities1] AS [e]
""",
//
Expand All @@ -280,10 +277,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] AS [X]
FROM [Entities1] AS [e]
""",
//
Expand Down Expand Up @@ -364,10 +358,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] AS [X]
FROM [Entities1] AS [e]
""",
//
Expand All @@ -392,10 +383,7 @@ WHERE [e].[BoolA] <> [e].[NullableBoolB]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) AS [X]
FROM [Entities1] AS [e]
""",
//
Expand Down Expand Up @@ -476,10 +464,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] AS [X]
FROM [Entities1] AS [e]
""",
//
Expand All @@ -504,10 +489,7 @@ WHERE [e].[BoolA] <> [e].[NullableBoolB] OR [e].[NullableBoolB] IS NULL
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) AS [X]
FROM [Entities1] AS [e]
""",
//
Expand Down Expand Up @@ -588,10 +570,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) AS [X]
FROM [Entities1] AS [e]
""",
//
Expand All @@ -616,10 +595,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] AS [X]
FROM [Entities1] AS [e]
""",
//
Expand Down Expand Up @@ -700,10 +676,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] AS [X]
FROM [Entities1] AS [e]
""",
//
Expand All @@ -728,10 +701,7 @@ WHERE [e].[BoolA] <> [e].[NullableBoolB] OR [e].[NullableBoolB] IS NULL
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) AS [X]
FROM [Entities1] AS [e]
""",
//
Expand Down Expand Up @@ -812,10 +782,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) AS [X]
FROM [Entities1] AS [e]
""",
//
Expand All @@ -840,10 +807,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] AS [X]
FROM [Entities1] AS [e]
""",
//
Expand Down Expand Up @@ -924,10 +888,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) AS [X]
FROM [Entities1] AS [e]
""",
//
Expand All @@ -952,10 +913,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] AS [X]
FROM [Entities1] AS [e]
""",
//
Expand Down Expand Up @@ -1036,10 +994,7 @@ FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] AS [X]
FROM [Entities1] AS [e]
""",
//
Expand All @@ -1064,10 +1019,7 @@ WHERE [e].[BoolA] <> [e].[NullableBoolB] AND [e].[NullableBoolB] IS NOT NULL
""",
//
"""
SELECT [e].[Id], CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [X]
SELECT [e].[Id], [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) AS [X]
FROM [Entities1] AS [e]
""",
//
Expand Down Expand Up @@ -1756,10 +1708,7 @@ public override async Task Compare_complex_equal_equal_equal(bool async)
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END = CASE
WHERE [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) = CASE
WHEN [e].[IntA] = [e].[IntB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
Expand Down Expand Up @@ -1798,10 +1747,7 @@ public override async Task Compare_complex_equal_not_equal_equal(bool async)
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END <> CASE
WHERE [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit) <> CASE
WHEN [e].[IntA] = [e].[IntB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
Expand Down Expand Up @@ -1840,10 +1786,7 @@ public override async Task Compare_complex_not_equal_equal_equal(bool async)
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END = CASE
WHERE [e].[BoolA] ^ [e].[BoolB] = CASE
WHEN [e].[IntA] = [e].[IntB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
Expand Down Expand Up @@ -1882,10 +1825,7 @@ public override async Task Compare_complex_not_equal_not_equal_equal(bool async)
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END <> CASE
WHERE [e].[BoolA] ^ [e].[BoolB] <> CASE
WHEN [e].[IntA] = [e].[IntB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
Expand Down Expand Up @@ -1924,10 +1864,7 @@ public override async Task Compare_complex_not_equal_equal_not_equal(bool async)
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END = CASE
WHERE [e].[BoolA] ^ [e].[BoolB] = CASE
WHEN [e].[IntA] <> [e].[IntB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
Expand Down Expand Up @@ -1966,10 +1903,7 @@ public override async Task Compare_complex_not_equal_not_equal_not_equal(bool as
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[BoolA] <> [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END <> CASE
WHERE [e].[BoolA] ^ [e].[BoolB] <> CASE
WHEN [e].[IntA] <> [e].[IntB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
Expand Down Expand Up @@ -4512,10 +4446,7 @@ public override async Task Is_null_on_column_followed_by_OrElse_optimizes_nullab
SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[NullableBoolA] IS NULL THEN CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
WHEN [e].[NullableBoolA] IS NULL THEN [e].[BoolA] ^ [e].[BoolB] ^ CAST(1 AS bit)
WHEN [e].[NullableBoolC] IS NULL THEN CASE
WHEN ([e].[NullableBoolA] <> [e].[NullableBoolC] OR [e].[NullableBoolA] IS NULL OR [e].[NullableBoolC] IS NULL) AND ([e].[NullableBoolA] IS NOT NULL OR [e].[NullableBoolC] IS NOT NULL) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
Expand Down

0 comments on commit 0337960

Please sign in to comment.