Skip to content

Commit

Permalink
Remove invalid optimization of CASE-WHEN expressions (dotnet#33754)
Browse files Browse the repository at this point in the history
* Add tests for `CASE WHEN END = const` optimization

* Support optimization of `CompareTo(a, b) == {-1,0,1}`

* Remove invalid optimization of `CASE WHEN ... END = const`

Fixes dotnet#33751
  • Loading branch information
ranma42 authored May 21, 2024
1 parent 0f23ab7 commit e969995
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ private SqlExpression OptimizeCompareTo(

return operatorType switch
{
// CompareTo(a, b) == 0 -> a == b
// CompareTo(a, b) == 1 -> a > b
// CompareTo(a, b) = -1 -> a < b
ExpressionType.Equal => (SqlExpression)Visit(
intValue switch
{
0 => _sqlExpressionFactory.Equal(testLeft, testRight),
1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
_ => _sqlExpressionFactory.LessThan(testLeft, testRight)
}),
// CompareTo(a, b) != 0 -> a != b
// CompareTo(a, b) != 1 -> a <= b
// CompareTo(a, b) != -1 -> a >= b
Expand Down Expand Up @@ -203,29 +213,13 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression;
var caseComponent = sqlBinaryExpression.Left as CaseExpression ?? sqlBinaryExpression.Right as CaseExpression;

// generic CASE statement comparison optimization:
// (CASE
// WHEN condition1 THEN result1
// WHEN condition2 THEN result2
// WHEN ...
// WHEN conditionN THEN resultN) == result1 -> condition1
if (sqlBinaryExpression.OperatorType == ExpressionType.Equal
&& sqlConstantComponent?.Value is not null
&& caseComponent is { Operand: null, ElseResult: null })
{
var matchingCaseBlock = caseComponent.WhenClauses.FirstOrDefault(wc => sqlConstantComponent.Equals(wc.Result));
if (matchingCaseBlock != null)
{
return Visit(matchingCaseBlock.Test);
}
}

// CompareTo specific optimizations
if (sqlConstantComponent != null
&& IsCompareTo(caseComponent)
&& sqlConstantComponent.Value is int intValue and > -2 and < 2
&& sqlBinaryExpression.OperatorType
is ExpressionType.NotEqual
is ExpressionType.Equal
or ExpressionType.NotEqual
or ExpressionType.GreaterThan
or ExpressionType.GreaterThanOrEqual
or ExpressionType.LessThan
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.TestModels.NullSemanticsModel;

namespace Microsoft.EntityFrameworkCore.Query;
Expand Down Expand Up @@ -119,5 +120,20 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
modelBuilder.Entity<NullSemanticsEntity2>().Property(e => e.StringA).IsRequired();
modelBuilder.Entity<NullSemanticsEntity2>().Property(e => e.StringB).IsRequired();
modelBuilder.Entity<NullSemanticsEntity2>().Property(e => e.StringC).IsRequired();

modelBuilder.HasDbFunction(
typeof(NullSemanticsQueryFixtureBase).GetMethod(nameof(Cases)),
b => b.HasTranslation(args => new CaseExpression([
new CaseWhenClause(args[0], args[1]),
new CaseWhenClause(args[2], args[3]),
new CaseWhenClause(args[4], args[5]),
]))
);
}

public static int? Cases(bool c1, int v1, bool c2, int v2, bool c3, int v3) =>
c1 ? v1 :
c2 ? v2 :
c3 ? v3 :
null;
}
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,62 @@ await AssertQuery(
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA != e.NullableStringB));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task CaseWhen_equal_to_second_filter(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>()
.Where(x => NullSemanticsQueryFixtureBase.Cases(
x.StringA == "Foo", 3,
x.StringB == "Foo", 2,
x.StringC == "Foo", 3
) == 2)
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task CaseWhen_equal_to_first_or_third_filter(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>()
.Where(x => NullSemanticsQueryFixtureBase.Cases(
x.StringA == "Foo", 3,
x.StringB == "Foo", 2,
x.StringC == "Foo", 3
) == 3)
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task CaseWhen_equal_to_second_select(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>()
.OrderBy(x => x.Id)
.Select(x => NullSemanticsQueryFixtureBase.Cases(
x.StringA == "Foo", 3,
x.StringB == "Foo", 2,
x.StringC == "Foo", 3
) == 2),
assertOrder: true
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task CaseWhen_equal_to_first_or_third_select(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>()
.OrderBy(x => x.Id)
.Select(x => NullSemanticsQueryFixtureBase.Cases(
x.StringA == "Foo", 3,
x.StringB == "Foo", 2,
x.StringC == "Foo", 3
) == 3),
assertOrder: true
);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task False_compared_to_negated_is_null(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2717,6 +2717,84 @@ WHERE [e].[NullableStringA] IS NULL
""");
}

public override async Task CaseWhen_equal_to_second_filter(bool async)
{
await base.CaseWhen_equal_to_second_filter(async);

AssertSql(
"""
SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[StringA] = N'Foo' THEN 3
WHEN [e].[StringB] = N'Foo' THEN 2
WHEN [e].[StringC] = N'Foo' THEN 3
END = 2
""");
}

public override async Task CaseWhen_equal_to_first_or_third_filter(bool async)
{
await base.CaseWhen_equal_to_first_or_third_filter(async);

AssertSql(
"""
SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[StringA] = N'Foo' THEN 3
WHEN [e].[StringB] = N'Foo' THEN 2
WHEN [e].[StringC] = N'Foo' THEN 3
END = 3
""");
}

public override async Task CaseWhen_equal_to_second_select(bool async)
{
await base.CaseWhen_equal_to_second_select(async);

AssertSql(
"""
SELECT CASE
WHEN CASE
WHEN [e].[StringA] = N'Foo' THEN 3
WHEN [e].[StringB] = N'Foo' THEN 2
WHEN [e].[StringC] = N'Foo' THEN 3
END = 2 AND CASE
WHEN [e].[StringA] = N'Foo' THEN 3
WHEN [e].[StringB] = N'Foo' THEN 2
WHEN [e].[StringC] = N'Foo' THEN 3
END IS NOT NULL THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
FROM [Entities1] AS [e]
ORDER BY [e].[Id]
""");
}

public override async Task CaseWhen_equal_to_first_or_third_select(bool async)
{
await base.CaseWhen_equal_to_first_or_third_select(async);

AssertSql(
"""
SELECT CASE
WHEN CASE
WHEN [e].[StringA] = N'Foo' THEN 3
WHEN [e].[StringB] = N'Foo' THEN 2
WHEN [e].[StringC] = N'Foo' THEN 3
END = 3 AND CASE
WHEN [e].[StringA] = N'Foo' THEN 3
WHEN [e].[StringB] = N'Foo' THEN 2
WHEN [e].[StringC] = N'Foo' THEN 3
END IS NOT NULL THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
FROM [Entities1] AS [e]
ORDER BY [e].[Id]
""");
}

public override async Task Multiple_non_equality_comparisons_with_null_in_the_middle(bool async)
{
await base.Multiple_non_equality_comparisons_with_null_in_the_middle(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,78 @@ SELECT 1
""");
}

public override async Task CaseWhen_equal_to_second_filter(bool async)
{
await base.CaseWhen_equal_to_second_filter(async);

AssertSql(
"""
SELECT "e"."Id", "e"."BoolA", "e"."BoolB", "e"."BoolC", "e"."IntA", "e"."IntB", "e"."IntC", "e"."NullableBoolA", "e"."NullableBoolB", "e"."NullableBoolC", "e"."NullableIntA", "e"."NullableIntB", "e"."NullableIntC", "e"."NullableStringA", "e"."NullableStringB", "e"."NullableStringC", "e"."StringA", "e"."StringB", "e"."StringC"
FROM "Entities1" AS "e"
WHERE CASE
WHEN "e"."StringA" = 'Foo' THEN 3
WHEN "e"."StringB" = 'Foo' THEN 2
WHEN "e"."StringC" = 'Foo' THEN 3
END = 2
""");
}

public override async Task CaseWhen_equal_to_first_or_third_filter(bool async)
{
await base.CaseWhen_equal_to_first_or_third_filter(async);

AssertSql(
"""
SELECT "e"."Id", "e"."BoolA", "e"."BoolB", "e"."BoolC", "e"."IntA", "e"."IntB", "e"."IntC", "e"."NullableBoolA", "e"."NullableBoolB", "e"."NullableBoolC", "e"."NullableIntA", "e"."NullableIntB", "e"."NullableIntC", "e"."NullableStringA", "e"."NullableStringB", "e"."NullableStringC", "e"."StringA", "e"."StringB", "e"."StringC"
FROM "Entities1" AS "e"
WHERE CASE
WHEN "e"."StringA" = 'Foo' THEN 3
WHEN "e"."StringB" = 'Foo' THEN 2
WHEN "e"."StringC" = 'Foo' THEN 3
END = 3
""");
}

public override async Task CaseWhen_equal_to_second_select(bool async)
{
await base.CaseWhen_equal_to_second_select(async);

AssertSql(
"""
SELECT CASE
WHEN "e"."StringA" = 'Foo' THEN 3
WHEN "e"."StringB" = 'Foo' THEN 2
WHEN "e"."StringC" = 'Foo' THEN 3
END = 2 AND CASE
WHEN "e"."StringA" = 'Foo' THEN 3
WHEN "e"."StringB" = 'Foo' THEN 2
WHEN "e"."StringC" = 'Foo' THEN 3
END IS NOT NULL
FROM "Entities1" AS "e"
ORDER BY "e"."Id"
""");
}

public override async Task CaseWhen_equal_to_first_or_third_select(bool async)
{
await base.CaseWhen_equal_to_first_or_third_select(async);

AssertSql(
"""
SELECT CASE
WHEN "e"."StringA" = 'Foo' THEN 3
WHEN "e"."StringB" = 'Foo' THEN 2
WHEN "e"."StringC" = 'Foo' THEN 3
END = 3 AND CASE
WHEN "e"."StringA" = 'Foo' THEN 3
WHEN "e"."StringB" = 'Foo' THEN 2
WHEN "e"."StringC" = 'Foo' THEN 3
END IS NOT NULL
FROM "Entities1" AS "e"
ORDER BY "e"."Id"
""");
}

public override async Task Bool_equal_nullable_bool_HasValue(bool async)
{
await base.Bool_equal_nullable_bool_HasValue(async);
Expand Down

0 comments on commit e969995

Please sign in to comment.