Skip to content

Commit

Permalink
Fix to #16092 - Query: Simplify case blocks in SQL tree
Browse files Browse the repository at this point in the history
Adding optimization to during post processing
Trying to match CASE block that corresponds to CompareTo translation. If that case block is compared to 0, 1 or -1 we can simplify it to simple comparison.
Also added generic CASE block optimization, when constant is compared to CASE block, and that constant is one of the results

Fixes #16092
  • Loading branch information
maumar committed May 15, 2020
1 parent 7f81b60 commit c643051
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 353 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,164 @@ protected override Expression VisitExtension(Expression extensionExpression)
nestedCaseExpression.ElseResult));
}

if (extensionExpression is SqlBinaryExpression sqlBinaryExpression)
{
var sqlConstantComponent = sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression;
var caseComponent = sqlBinaryExpression.Left as CaseExpression ?? sqlBinaryExpression.Right as CaseExpression;

// generic CASE statement comparison optimization:
// (CASE
// WHEN condition1 THEN result1
// WHEN condition2 THEN result2
// WHEN ...
// WHEN conditionN THEN resultN) == result1 -> condition1
if (sqlBinaryExpression.OperatorType == ExpressionType.Equal
&& sqlConstantComponent != null
&& sqlConstantComponent.Value != null
&& caseComponent != null
&& caseComponent.Operand == null)
{
var matchingCaseBlock = caseComponent.WhenClauses.FirstOrDefault(wc => sqlConstantComponent.Equals(wc.Result));
if (matchingCaseBlock != null)
{
return matchingCaseBlock.Test;
}
}

// CompareTo specific optimizations
if (sqlConstantComponent != null
&& sqlConstantComponent.Value != null
&& IsCompareTo(caseComponent)
&& sqlConstantComponent.Value is int intValue
&& (intValue == 0 || intValue == 1 || intValue == -1)
&& (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThan
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThanOrEqual
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThan
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThanOrEqual))
{
return OptimizeCompareTo(
sqlBinaryExpression,
intValue,
caseComponent);
}
}

return base.VisitExtension(extensionExpression);
}

private bool IsCompareTo(CaseExpression caseExpression)
{
if (caseExpression != null
&& caseExpression.Operand == null
&& caseExpression.WhenClauses.Count == 3
&& caseExpression.WhenClauses.All(c => c.Test is SqlBinaryExpression
&& c.Result is SqlConstantExpression constant
&& constant.Value is int))
{
var whenClauses = caseExpression.WhenClauses.Select(c => new
{
test = (SqlBinaryExpression)c.Test,
resultValue = (int)((SqlConstantExpression)c.Result).Value
}).ToList();

if (whenClauses[0].test.Left.Equals(whenClauses[1].test.Left)
&& whenClauses[1].test.Left.Equals(whenClauses[2].test.Left)
&& whenClauses[0].test.Right.Equals(whenClauses[1].test.Right)
&& whenClauses[1].test.Right.Equals(whenClauses[2].test.Right)
&& whenClauses[0].test.OperatorType == ExpressionType.Equal
&& whenClauses[1].test.OperatorType == ExpressionType.GreaterThan
&& whenClauses[2].test.OperatorType == ExpressionType.LessThan
&& whenClauses[0].resultValue == 0
&& whenClauses[1].resultValue == 1
&& whenClauses[2].resultValue == -1)
{
return true;
}
}

return false;
}

private SqlExpression OptimizeCompareTo(
SqlBinaryExpression sqlBinaryExpression,
int intValue,
CaseExpression caseExpression)
{
var testLeft = ((SqlBinaryExpression)caseExpression.WhenClauses[0].Test).Left;
var testRight = ((SqlBinaryExpression)caseExpression.WhenClauses[0].Test).Right;
var operatorType = sqlBinaryExpression.Right is SqlConstantExpression
? sqlBinaryExpression.OperatorType
: sqlBinaryExpression.OperatorType switch
{
ExpressionType.GreaterThan => ExpressionType.LessThan,
ExpressionType.GreaterThanOrEqual => ExpressionType.LessThanOrEqual,
ExpressionType.LessThan => ExpressionType.GreaterThan,
ExpressionType.LessThanOrEqual => ExpressionType.GreaterThanOrEqual,
_ => sqlBinaryExpression.OperatorType
};

if (operatorType == ExpressionType.NotEqual)
{
// CompareTo(a, b) != 0 -> a != b
// CompareTo(a, b) != 1 -> a <= b
// CompareTo(a, b) != -1 -> a >= b
return intValue switch
{
0 => _sqlExpressionFactory.NotEqual(testLeft, testRight),
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight),
};
}
else if (operatorType == ExpressionType.GreaterThan)
{
// CompareTo(a, b) > 0 -> a > b
// CompareTo(a, b) > 1 -> false
// CompareTo(a, b) > -1 -> a >= b
return intValue switch
{
0 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
1 => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping),
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight),
};
}
else if (operatorType == ExpressionType.GreaterThanOrEqual)
{
// CompareTo(a, b) >= 0 -> a >= b
// CompareTo(a, b) >= 1 -> a > b
// CompareTo(a, b) >= -1 -> true
return intValue switch
{
0 => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight),
1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
_ => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping),
};
}
else if (operatorType == ExpressionType.LessThan)
{
// CompareTo(a, b) < 0 -> a < b
// CompareTo(a, b) < 1 -> a <= b
// CompareTo(a, b) < -1 -> false
return intValue switch
{
0 => _sqlExpressionFactory.LessThan(testLeft, testRight),
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
_ => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping),
};
}
else
{
// operatorType == ExpressionType.LessThanOrEqual
// CompareTo(a, b) <= 0 -> a <= b
// CompareTo(a, b) <= 1 -> true
// CompareTo(a, b) <= -1 -> a < b
return intValue switch
{
0 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
1 => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping),
_ => _sqlExpressionFactory.LessThan(testLeft, testRight),
};
}
}
}
}
13 changes: 6 additions & 7 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -751,14 +751,14 @@ protected virtual SqlExpression VisitScalarSubquery(

return scalarSubqueryExpression.Update(Visit(scalarSubqueryExpression.Subquery));
}

/// <summary>
/// Visits a <see cref="SqlBinaryExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlBinaryExpression"> A sql binary expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlBinary(
[NotNull] SqlBinaryExpression sqlBinaryExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand Down Expand Up @@ -862,7 +862,6 @@ protected virtual SqlExpression VisitSqlBinary(
}

nullable = leftNullable || rightNullable;

var result = sqlBinaryExpression.Update(left, right);

return result is SqlBinaryExpression sqlBinaryResult
Expand All @@ -876,14 +875,14 @@ SqlExpression AddNullConcatenationProtection(SqlExpression argument, RelationalT
? (SqlExpression)_sqlExpressionFactory.Constant(string.Empty, typeMapping)
: _sqlExpressionFactory.Coalesce(argument, _sqlExpressionFactory.Constant(string.Empty, typeMapping));
}

/// <summary>
/// Visits a <see cref="SqlConstantExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlConstantExpression"> A sql constant expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlConstant(
[NotNull] SqlConstantExpression sqlConstantExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand All @@ -893,14 +892,14 @@ protected virtual SqlExpression VisitSqlConstant(

return sqlConstantExpression;
}

/// <summary>
/// Visits a <see cref="SqlFragmentExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlFragmentExpression"> A sql fragment expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlFragment(
[NotNull] SqlFragmentExpression sqlFragmentExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand All @@ -910,14 +909,14 @@ protected virtual SqlExpression VisitSqlFragment(

return sqlFragmentExpression;
}

/// <summary>
/// Visits a <see cref="SqlFunctionExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlFunctionExpression"> A sql function expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlFunction(
[NotNull] SqlFunctionExpression sqlFunctionExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand Down Expand Up @@ -951,14 +950,14 @@ protected virtual SqlExpression VisitSqlFunction(

return sqlFunctionExpression.Update(instance, arguments);
}

/// <summary>
/// Visits a <see cref="SqlParameterExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlParameterExpression"> A sql parameter expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlParameter(
[NotNull] SqlParameterExpression sqlParameterExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand All @@ -970,14 +969,14 @@ protected virtual SqlExpression VisitSqlParameter(
? _sqlExpressionFactory.Constant(null, sqlParameterExpression.TypeMapping)
: (SqlExpression)sqlParameterExpression;
}

/// <summary>
/// Visits a <see cref="SqlUnaryExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlUnaryExpression"> A sql unary expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlUnary(
[NotNull] SqlUnaryExpression sqlUnaryExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,21 @@ public virtual Task Nullable_string_FirstOrDefault_compared_to_nullable_string_L
== e.NullableStringB.MaybeScalar(x => x.LastOrDefault())));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Null_semantics_applied_to_CompareTo_equality(bool async)
{
await AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA.CompareTo(e.NullableStringB) == 0),
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA == e.NullableStringB));

await AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA.CompareTo(e.NullableStringB) != 0),
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA != e.NullableStringB));
}

private string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5248,20 +5248,11 @@ public override async Task Double_order_by_on_string_compare(bool async)
{
await base.Double_order_by_on_string_compare(async);

// issue #16092
AssertSql(
@"SELECT [w].[Id], [w].[AmmunitionType], [w].[IsAutomatic], [w].[Name], [w].[OwnerFullName], [w].[SynergyWithId]
FROM [Weapons] AS [w]
ORDER BY CASE
WHEN (CASE
WHEN [w].[Name] = N'Marcus'' Lancer' THEN 0
WHEN [w].[Name] > N'Marcus'' Lancer' THEN 1
WHEN [w].[Name] < N'Marcus'' Lancer' THEN -1
END = 0) AND CASE
WHEN [w].[Name] = N'Marcus'' Lancer' THEN 0
WHEN [w].[Name] > N'Marcus'' Lancer' THEN 1
WHEN [w].[Name] < N'Marcus'' Lancer' THEN -1
END IS NOT NULL THEN CAST(1 AS bit)
WHEN ([w].[Name] = N'Marcus'' Lancer') AND [w].[Name] IS NOT NULL THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END, [w].[Id]");
}
Expand All @@ -5280,21 +5271,12 @@ public override async Task String_compare_with_null_conditional_argument(bool as
{
await base.String_compare_with_null_conditional_argument(async);

// issue #16092
AssertSql(
@"SELECT [w0].[Id], [w0].[AmmunitionType], [w0].[IsAutomatic], [w0].[Name], [w0].[OwnerFullName], [w0].[SynergyWithId]
FROM [Weapons] AS [w]
LEFT JOIN [Weapons] AS [w0] ON [w].[SynergyWithId] = [w0].[Id]
ORDER BY CASE
WHEN (CASE
WHEN [w0].[Name] = N'Marcus'' Lancer' THEN 0
WHEN [w0].[Name] > N'Marcus'' Lancer' THEN 1
WHEN [w0].[Name] < N'Marcus'' Lancer' THEN -1
END = 0) AND CASE
WHEN [w0].[Name] = N'Marcus'' Lancer' THEN 0
WHEN [w0].[Name] > N'Marcus'' Lancer' THEN 1
WHEN [w0].[Name] < N'Marcus'' Lancer' THEN -1
END IS NOT NULL THEN CAST(1 AS bit)
WHEN ([w0].[Name] = N'Marcus'' Lancer') AND [w0].[Name] IS NOT NULL THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END");
}
Expand All @@ -5303,21 +5285,12 @@ public override async Task String_compare_with_null_conditional_argument2(bool a
{
await base.String_compare_with_null_conditional_argument2(async);

// issue #16092
AssertSql(
@"SELECT [w0].[Id], [w0].[AmmunitionType], [w0].[IsAutomatic], [w0].[Name], [w0].[OwnerFullName], [w0].[SynergyWithId]
FROM [Weapons] AS [w]
LEFT JOIN [Weapons] AS [w0] ON [w].[SynergyWithId] = [w0].[Id]
ORDER BY CASE
WHEN (CASE
WHEN N'Marcus'' Lancer' = [w0].[Name] THEN 0
WHEN N'Marcus'' Lancer' > [w0].[Name] THEN 1
WHEN N'Marcus'' Lancer' < [w0].[Name] THEN -1
END = 0) AND CASE
WHEN N'Marcus'' Lancer' = [w0].[Name] THEN 0
WHEN N'Marcus'' Lancer' > [w0].[Name] THEN 1
WHEN N'Marcus'' Lancer' < [w0].[Name] THEN -1
END IS NOT NULL THEN CAST(1 AS bit)
WHEN (N'Marcus'' Lancer' = [w0].[Name]) AND [w0].[Name] IS NOT NULL THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END");
}
Expand Down
Loading

0 comments on commit c643051

Please sign in to comment.