From e969995b2701f228815259941467104b6ee8dbee Mon Sep 17 00:00:00 2001 From: Andrea Canciani Date: Tue, 21 May 2024 09:51:16 +0200 Subject: [PATCH] Remove invalid optimization of CASE-WHEN expressions (#33754) * Add tests for `CASE WHEN END = const` optimization * Support optimization of `CompareTo(a, b) == {-1,0,1}` * Remove invalid optimization of `CASE WHEN ... END = const` Fixes #33751 --- ...lExpressionSimplifyingExpressionVisitor.cs | 30 +++---- .../Query/NullSemanticsQueryFixtureBase.cs | 16 ++++ .../Query/NullSemanticsQueryTestBase.cs | 56 +++++++++++++ .../Query/NullSemanticsQuerySqlServerTest.cs | 78 +++++++++++++++++++ .../Query/NullSemanticsQuerySqliteTest.cs | 72 +++++++++++++++++ 5 files changed, 234 insertions(+), 18 deletions(-) diff --git a/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs index a252cdf6ff9..731aa047a82 100644 --- a/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs @@ -146,6 +146,16 @@ private SqlExpression OptimizeCompareTo( return operatorType switch { + // CompareTo(a, b) == 0 -> a == b + // CompareTo(a, b) == 1 -> a > b + // CompareTo(a, b) = -1 -> a < b + ExpressionType.Equal => (SqlExpression)Visit( + intValue switch + { + 0 => _sqlExpressionFactory.Equal(testLeft, testRight), + 1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight), + _ => _sqlExpressionFactory.LessThan(testLeft, testRight) + }), // CompareTo(a, b) != 0 -> a != b // CompareTo(a, b) != 1 -> a <= b // CompareTo(a, b) != -1 -> a >= b @@ -203,29 +213,13 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression) sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression; var caseComponent = sqlBinaryExpression.Left as CaseExpression ?? sqlBinaryExpression.Right as CaseExpression; - // generic CASE statement comparison optimization: - // (CASE - // WHEN condition1 THEN result1 - // WHEN condition2 THEN result2 - // WHEN ... - // WHEN conditionN THEN resultN) == result1 -> condition1 - if (sqlBinaryExpression.OperatorType == ExpressionType.Equal - && sqlConstantComponent?.Value is not null - && caseComponent is { Operand: null, ElseResult: null }) - { - var matchingCaseBlock = caseComponent.WhenClauses.FirstOrDefault(wc => sqlConstantComponent.Equals(wc.Result)); - if (matchingCaseBlock != null) - { - return Visit(matchingCaseBlock.Test); - } - } - // CompareTo specific optimizations if (sqlConstantComponent != null && IsCompareTo(caseComponent) && sqlConstantComponent.Value is int intValue and > -2 and < 2 && sqlBinaryExpression.OperatorType - is ExpressionType.NotEqual + is ExpressionType.Equal + or ExpressionType.NotEqual or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or ExpressionType.LessThan diff --git a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryFixtureBase.cs b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryFixtureBase.cs index 8b2e940e023..57589808a11 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryFixtureBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryFixtureBase.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; using Microsoft.EntityFrameworkCore.TestModels.NullSemanticsModel; namespace Microsoft.EntityFrameworkCore.Query; @@ -119,5 +120,20 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con modelBuilder.Entity().Property(e => e.StringA).IsRequired(); modelBuilder.Entity().Property(e => e.StringB).IsRequired(); modelBuilder.Entity().Property(e => e.StringC).IsRequired(); + + modelBuilder.HasDbFunction( + typeof(NullSemanticsQueryFixtureBase).GetMethod(nameof(Cases)), + b => b.HasTranslation(args => new CaseExpression([ + new CaseWhenClause(args[0], args[1]), + new CaseWhenClause(args[2], args[3]), + new CaseWhenClause(args[4], args[5]), + ])) + ); } + + public static int? Cases(bool c1, int v1, bool c2, int v2, bool c3, int v3) => + c1 ? v1 : + c2 ? v2 : + c3 ? v3 : + null; } diff --git a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs index 36befc97eac..9c512a44cbc 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs @@ -1720,6 +1720,62 @@ await AssertQuery( ss => ss.Set().Where(e => e.NullableStringA != e.NullableStringB)); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task CaseWhen_equal_to_second_filter(bool async) + => AssertQuery( + async, + ss => ss.Set() + .Where(x => NullSemanticsQueryFixtureBase.Cases( + x.StringA == "Foo", 3, + x.StringB == "Foo", 2, + x.StringC == "Foo", 3 + ) == 2) + ); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task CaseWhen_equal_to_first_or_third_filter(bool async) + => AssertQuery( + async, + ss => ss.Set() + .Where(x => NullSemanticsQueryFixtureBase.Cases( + x.StringA == "Foo", 3, + x.StringB == "Foo", 2, + x.StringC == "Foo", 3 + ) == 3) + ); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task CaseWhen_equal_to_second_select(bool async) + => AssertQuery( + async, + ss => ss.Set() + .OrderBy(x => x.Id) + .Select(x => NullSemanticsQueryFixtureBase.Cases( + x.StringA == "Foo", 3, + x.StringB == "Foo", 2, + x.StringC == "Foo", 3 + ) == 2), + assertOrder: true + ); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task CaseWhen_equal_to_first_or_third_select(bool async) + => AssertQuery( + async, + ss => ss.Set() + .OrderBy(x => x.Id) + .Select(x => NullSemanticsQueryFixtureBase.Cases( + x.StringA == "Foo", 3, + x.StringB == "Foo", 2, + x.StringC == "Foo", 3 + ) == 3), + assertOrder: true + ); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task False_compared_to_negated_is_null(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs index 0453bf31547..cecc6e20f4d 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs @@ -2717,6 +2717,84 @@ WHERE [e].[NullableStringA] IS NULL """); } + public override async Task CaseWhen_equal_to_second_filter(bool async) + { + await base.CaseWhen_equal_to_second_filter(async); + + AssertSql( + """ +SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC] +FROM [Entities1] AS [e] +WHERE CASE + WHEN [e].[StringA] = N'Foo' THEN 3 + WHEN [e].[StringB] = N'Foo' THEN 2 + WHEN [e].[StringC] = N'Foo' THEN 3 +END = 2 +"""); + } + + public override async Task CaseWhen_equal_to_first_or_third_filter(bool async) + { + await base.CaseWhen_equal_to_first_or_third_filter(async); + + AssertSql( + """ +SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC] +FROM [Entities1] AS [e] +WHERE CASE + WHEN [e].[StringA] = N'Foo' THEN 3 + WHEN [e].[StringB] = N'Foo' THEN 2 + WHEN [e].[StringC] = N'Foo' THEN 3 +END = 3 +"""); + } + + public override async Task CaseWhen_equal_to_second_select(bool async) + { + await base.CaseWhen_equal_to_second_select(async); + + AssertSql( + """ +SELECT CASE + WHEN CASE + WHEN [e].[StringA] = N'Foo' THEN 3 + WHEN [e].[StringB] = N'Foo' THEN 2 + WHEN [e].[StringC] = N'Foo' THEN 3 + END = 2 AND CASE + WHEN [e].[StringA] = N'Foo' THEN 3 + WHEN [e].[StringB] = N'Foo' THEN 2 + WHEN [e].[StringC] = N'Foo' THEN 3 + END IS NOT NULL THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +FROM [Entities1] AS [e] +ORDER BY [e].[Id] +"""); + } + + public override async Task CaseWhen_equal_to_first_or_third_select(bool async) + { + await base.CaseWhen_equal_to_first_or_third_select(async); + + AssertSql( + """ +SELECT CASE + WHEN CASE + WHEN [e].[StringA] = N'Foo' THEN 3 + WHEN [e].[StringB] = N'Foo' THEN 2 + WHEN [e].[StringC] = N'Foo' THEN 3 + END = 3 AND CASE + WHEN [e].[StringA] = N'Foo' THEN 3 + WHEN [e].[StringB] = N'Foo' THEN 2 + WHEN [e].[StringC] = N'Foo' THEN 3 + END IS NOT NULL THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END +FROM [Entities1] AS [e] +ORDER BY [e].[Id] +"""); + } + public override async Task Multiple_non_equality_comparisons_with_null_in_the_middle(bool async) { await base.Multiple_non_equality_comparisons_with_null_in_the_middle(async); diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NullSemanticsQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NullSemanticsQuerySqliteTest.cs index d99f554154f..12b90479a50 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/NullSemanticsQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/NullSemanticsQuerySqliteTest.cs @@ -113,6 +113,78 @@ SELECT 1 """); } + public override async Task CaseWhen_equal_to_second_filter(bool async) + { + await base.CaseWhen_equal_to_second_filter(async); + + AssertSql( + """ +SELECT "e"."Id", "e"."BoolA", "e"."BoolB", "e"."BoolC", "e"."IntA", "e"."IntB", "e"."IntC", "e"."NullableBoolA", "e"."NullableBoolB", "e"."NullableBoolC", "e"."NullableIntA", "e"."NullableIntB", "e"."NullableIntC", "e"."NullableStringA", "e"."NullableStringB", "e"."NullableStringC", "e"."StringA", "e"."StringB", "e"."StringC" +FROM "Entities1" AS "e" +WHERE CASE + WHEN "e"."StringA" = 'Foo' THEN 3 + WHEN "e"."StringB" = 'Foo' THEN 2 + WHEN "e"."StringC" = 'Foo' THEN 3 +END = 2 +"""); + } + + public override async Task CaseWhen_equal_to_first_or_third_filter(bool async) + { + await base.CaseWhen_equal_to_first_or_third_filter(async); + + AssertSql( + """ +SELECT "e"."Id", "e"."BoolA", "e"."BoolB", "e"."BoolC", "e"."IntA", "e"."IntB", "e"."IntC", "e"."NullableBoolA", "e"."NullableBoolB", "e"."NullableBoolC", "e"."NullableIntA", "e"."NullableIntB", "e"."NullableIntC", "e"."NullableStringA", "e"."NullableStringB", "e"."NullableStringC", "e"."StringA", "e"."StringB", "e"."StringC" +FROM "Entities1" AS "e" +WHERE CASE + WHEN "e"."StringA" = 'Foo' THEN 3 + WHEN "e"."StringB" = 'Foo' THEN 2 + WHEN "e"."StringC" = 'Foo' THEN 3 +END = 3 +"""); + } + + public override async Task CaseWhen_equal_to_second_select(bool async) + { + await base.CaseWhen_equal_to_second_select(async); + + AssertSql( + """ +SELECT CASE + WHEN "e"."StringA" = 'Foo' THEN 3 + WHEN "e"."StringB" = 'Foo' THEN 2 + WHEN "e"."StringC" = 'Foo' THEN 3 +END = 2 AND CASE + WHEN "e"."StringA" = 'Foo' THEN 3 + WHEN "e"."StringB" = 'Foo' THEN 2 + WHEN "e"."StringC" = 'Foo' THEN 3 +END IS NOT NULL +FROM "Entities1" AS "e" +ORDER BY "e"."Id" +"""); + } + + public override async Task CaseWhen_equal_to_first_or_third_select(bool async) + { + await base.CaseWhen_equal_to_first_or_third_select(async); + + AssertSql( + """ +SELECT CASE + WHEN "e"."StringA" = 'Foo' THEN 3 + WHEN "e"."StringB" = 'Foo' THEN 2 + WHEN "e"."StringC" = 'Foo' THEN 3 +END = 3 AND CASE + WHEN "e"."StringA" = 'Foo' THEN 3 + WHEN "e"."StringB" = 'Foo' THEN 2 + WHEN "e"."StringC" = 'Foo' THEN 3 +END IS NOT NULL +FROM "Entities1" AS "e" +ORDER BY "e"."Id" +"""); + } + public override async Task Bool_equal_nullable_bool_HasValue(bool async) { await base.Bool_equal_nullable_bool_HasValue(async);