-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix to #16092 - Query: Simplify case blocks in SQL tree
Adding optimization to during post processing Trying to match CASE block that corresponds to CompareTo translation. If that case block is compared to 0, 1 or -1 we can simplify it to simple comparison. Also added generic CASE block optimization, when constant is compared to CASE block, and that constant is one of the results Fixes #16092
- Loading branch information
Showing
8 changed files
with
618 additions
and
410 deletions.
There are no files selected for viewing
216 changes: 216 additions & 0 deletions
216
src/EFCore.Relational/Query/Internal/CaseSimplifyingExpressionVisitor.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
// Copyright (c) .NET Foundation. All rights reserved. | ||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. | ||
|
||
using System.Linq; | ||
using System.Linq.Expressions; | ||
using JetBrains.Annotations; | ||
using Microsoft.EntityFrameworkCore.Query.SqlExpressions; | ||
using Microsoft.EntityFrameworkCore.Utilities; | ||
|
||
namespace Microsoft.EntityFrameworkCore.Query.Internal | ||
{ | ||
/// <summary> | ||
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to | ||
/// the same compatibility standards as public APIs. It may be changed or removed without notice in | ||
/// any release. You should only use it directly in your code with extreme caution and knowing that | ||
/// doing so can result in application failures when updating to a new Entity Framework Core release. | ||
/// </summary> | ||
public class CaseSimplifyingExpressionVisitor : ExpressionVisitor | ||
{ | ||
private readonly ISqlExpressionFactory _sqlExpressionFactory; | ||
|
||
/// <summary> | ||
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to | ||
/// the same compatibility standards as public APIs. It may be changed or removed without notice in | ||
/// any release. You should only use it directly in your code with extreme caution and knowing that | ||
/// doing so can result in application failures when updating to a new Entity Framework Core release. | ||
/// </summary> | ||
public CaseSimplifyingExpressionVisitor([NotNull] ISqlExpressionFactory sqlExpressionFactory) | ||
{ | ||
_sqlExpressionFactory = sqlExpressionFactory; | ||
} | ||
|
||
/// <summary> | ||
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to | ||
/// the same compatibility standards as public APIs. It may be changed or removed without notice in | ||
/// any release. You should only use it directly in your code with extreme caution and knowing that | ||
/// doing so can result in application failures when updating to a new Entity Framework Core release. | ||
/// </summary> | ||
protected override Expression VisitExtension(Expression extensionExpression) | ||
{ | ||
Check.NotNull(extensionExpression, nameof(extensionExpression)); | ||
|
||
if (extensionExpression is ShapedQueryExpression shapedQueryExpression) | ||
{ | ||
return shapedQueryExpression.Update(Visit(shapedQueryExpression.QueryExpression), shapedQueryExpression.ShaperExpression); | ||
} | ||
|
||
// Only applies to 'CASE WHEN condition...' not 'CASE operand WHEN...' | ||
if (extensionExpression is CaseExpression caseExpression | ||
&& caseExpression.Operand == null | ||
&& caseExpression.ElseResult is CaseExpression nestedCaseExpression | ||
&& nestedCaseExpression.Operand == null) | ||
{ | ||
return VisitExtension(_sqlExpressionFactory.Case( | ||
caseExpression.WhenClauses.Union(nestedCaseExpression.WhenClauses).ToList(), | ||
nestedCaseExpression.ElseResult)); | ||
} | ||
|
||
if (extensionExpression is SqlBinaryExpression sqlBinaryExpression) | ||
{ | ||
var sqlConstantComponent = sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression; | ||
var caseComponent = sqlBinaryExpression.Left as CaseExpression ?? sqlBinaryExpression.Right as CaseExpression; | ||
|
||
// generic CASE statement comparison optimization: | ||
// (CASE | ||
// WHEN condition1 THEN result1 | ||
// WHEN condition2 THEN result2 | ||
// WHEN ... | ||
// WHEN conditionN THEN resultN) == result1 -> condition1 | ||
if (sqlBinaryExpression.OperatorType == ExpressionType.Equal | ||
&& sqlConstantComponent != null | ||
&& sqlConstantComponent.Value != null | ||
&& caseComponent != null | ||
&& caseComponent.Operand == null) | ||
{ | ||
var matchingCaseBlock = caseComponent.WhenClauses.FirstOrDefault(wc => sqlConstantComponent.Equals(wc.Result)); | ||
if (matchingCaseBlock != null) | ||
{ | ||
return Visit(matchingCaseBlock.Test); | ||
} | ||
} | ||
|
||
// CompareTo specific optimizations | ||
if (sqlConstantComponent != null | ||
&& IsCompareTo(caseComponent) | ||
&& sqlConstantComponent.Value is int intValue | ||
&& (intValue > -2 && intValue < 2) | ||
&& (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual | ||
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThan | ||
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThanOrEqual | ||
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThan | ||
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThanOrEqual)) | ||
{ | ||
return OptimizeCompareTo( | ||
sqlBinaryExpression, | ||
intValue, | ||
caseComponent); | ||
} | ||
} | ||
|
||
return base.VisitExtension(extensionExpression); | ||
} | ||
|
||
private bool IsCompareTo(CaseExpression caseExpression) | ||
{ | ||
if (caseExpression != null | ||
&& caseExpression.Operand == null | ||
&& caseExpression.ElseResult == null | ||
&& caseExpression.WhenClauses.Count == 3 | ||
&& caseExpression.WhenClauses.All(c => c.Test is SqlBinaryExpression | ||
&& c.Result is SqlConstantExpression constant | ||
&& constant.Value is int)) | ||
{ | ||
var whenClauses = caseExpression.WhenClauses.Select(c => new | ||
{ | ||
test = (SqlBinaryExpression)c.Test, | ||
resultValue = (int)((SqlConstantExpression)c.Result).Value | ||
}).ToList(); | ||
|
||
if (whenClauses[0].test.Left.Equals(whenClauses[1].test.Left) | ||
&& whenClauses[1].test.Left.Equals(whenClauses[2].test.Left) | ||
&& whenClauses[0].test.Right.Equals(whenClauses[1].test.Right) | ||
&& whenClauses[1].test.Right.Equals(whenClauses[2].test.Right) | ||
&& whenClauses[0].test.OperatorType == ExpressionType.Equal | ||
&& whenClauses[1].test.OperatorType == ExpressionType.GreaterThan | ||
&& whenClauses[2].test.OperatorType == ExpressionType.LessThan | ||
&& whenClauses[0].resultValue == 0 | ||
&& whenClauses[1].resultValue == 1 | ||
&& whenClauses[2].resultValue == -1) | ||
{ | ||
return true; | ||
} | ||
} | ||
|
||
return false; | ||
} | ||
|
||
private SqlExpression OptimizeCompareTo( | ||
SqlBinaryExpression sqlBinaryExpression, | ||
int intValue, | ||
CaseExpression caseExpression) | ||
{ | ||
var testLeft = ((SqlBinaryExpression)caseExpression.WhenClauses[0].Test).Left; | ||
var testRight = ((SqlBinaryExpression)caseExpression.WhenClauses[0].Test).Right; | ||
var operatorType = sqlBinaryExpression.Right is SqlConstantExpression | ||
? sqlBinaryExpression.OperatorType | ||
: sqlBinaryExpression.OperatorType switch | ||
{ | ||
ExpressionType.GreaterThan => ExpressionType.LessThan, | ||
ExpressionType.GreaterThanOrEqual => ExpressionType.LessThanOrEqual, | ||
ExpressionType.LessThan => ExpressionType.GreaterThan, | ||
ExpressionType.LessThanOrEqual => ExpressionType.GreaterThanOrEqual, | ||
_ => sqlBinaryExpression.OperatorType | ||
}; | ||
|
||
switch (operatorType) | ||
{ | ||
// CompareTo(a, b) != 0 -> a != b | ||
// CompareTo(a, b) != 1 -> a <= b | ||
// CompareTo(a, b) != -1 -> a >= b | ||
case ExpressionType.NotEqual: | ||
return (SqlExpression)Visit(intValue switch | ||
{ | ||
0 => _sqlExpressionFactory.NotEqual(testLeft, testRight), | ||
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), | ||
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight), | ||
}); | ||
|
||
// CompareTo(a, b) > 0 -> a > b | ||
// CompareTo(a, b) > 1 -> false | ||
// CompareTo(a, b) > -1 -> a >= b | ||
case ExpressionType.GreaterThan: | ||
return (SqlExpression)Visit(intValue switch | ||
{ | ||
0 => _sqlExpressionFactory.GreaterThan(testLeft, testRight), | ||
1 => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping), | ||
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight), | ||
}); | ||
|
||
// CompareTo(a, b) >= 0 -> a >= b | ||
// CompareTo(a, b) >= 1 -> a > b | ||
// CompareTo(a, b) >= -1 -> true | ||
case ExpressionType.GreaterThanOrEqual: | ||
return (SqlExpression)Visit(intValue switch | ||
{ | ||
0 => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight), | ||
1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight), | ||
_ => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping), | ||
}); | ||
|
||
// CompareTo(a, b) < 0 -> a < b | ||
// CompareTo(a, b) < 1 -> a <= b | ||
// CompareTo(a, b) < -1 -> false | ||
case ExpressionType.LessThan: | ||
return (SqlExpression)Visit(intValue switch | ||
{ | ||
0 => _sqlExpressionFactory.LessThan(testLeft, testRight), | ||
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), | ||
_ => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping), | ||
}); | ||
|
||
// operatorType == ExpressionType.LessThanOrEqual | ||
// CompareTo(a, b) <= 0 -> a <= b | ||
// CompareTo(a, b) <= 1 -> true | ||
// CompareTo(a, b) <= -1 -> a < b | ||
default: | ||
return (SqlExpression)Visit(intValue switch | ||
{ | ||
0 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight), | ||
1 => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping), | ||
_ => _sqlExpressionFactory.LessThan(testLeft, testRight), | ||
}); | ||
}; | ||
} | ||
} | ||
} |
62 changes: 0 additions & 62 deletions
62
src/EFCore.Relational/Query/Internal/CaseWhenFlatteningExpressionVisitor.cs
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.