Skip to content

Commit

Permalink
Fix to #16092 - Query: Simplify case blocks in SQL tree
Browse files Browse the repository at this point in the history
Adding optimization to during post processing (null semantics)
Trying to match CASE block that corresponds to CompareTo translation. If that case block is compared to 0, 1 or -1 we can simplify it to simple comparison.
Also added generic CASE block optimization, when constant is compared to CASE block, and that constant is one of the results

Fixes #16092
  • Loading branch information
maumar committed May 15, 2020
1 parent 7f81b60 commit 0f36c1d
Show file tree
Hide file tree
Showing 5 changed files with 515 additions and 353 deletions.
186 changes: 179 additions & 7 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -751,19 +751,52 @@ protected virtual SqlExpression VisitScalarSubquery(

return scalarSubqueryExpression.Update(Visit(scalarSubqueryExpression.Subquery));
}

/// <summary>
/// Visits a <see cref="SqlBinaryExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlBinaryExpression"> A sql binary expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlBinary(
[NotNull] SqlBinaryExpression sqlBinaryExpression, bool allowOptimizedExpansion, out bool nullable)
{
Check.NotNull(sqlBinaryExpression, nameof(sqlBinaryExpression));

// we need to do this before we visit left/right
// otherwise detecting CompareTo block becomes hard due to null semantics expansion
// also we need to apply null semantics on the potential result
var sqlConstantExpression = sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression;
var caseExpression = sqlBinaryExpression.Left as CaseExpression ?? sqlBinaryExpression.Right as CaseExpression;
if (sqlConstantExpression != null
&& sqlConstantExpression.Value != null
&& IsCompareTo(caseExpression)
&& sqlConstantExpression.Value is int intValue
&& (intValue == 0 || intValue == 1 || intValue == -1)
&& (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThan
|| sqlBinaryExpression.OperatorType == ExpressionType.GreaterThanOrEqual
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThan
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThanOrEqual))
{
var compareToOptimized = OptimizeCompareTo(
sqlBinaryExpression,
intValue,
caseExpression);

if (compareToOptimized is SqlConstantExpression)
{
nullable = true;

return compareToOptimized;
}
else
{
sqlBinaryExpression = (SqlBinaryExpression)compareToOptimized;
}
}

var optimize = allowOptimizedExpansion;

allowOptimizedExpansion = allowOptimizedExpansion
Expand Down Expand Up @@ -862,7 +895,6 @@ protected virtual SqlExpression VisitSqlBinary(
}

nullable = leftNullable || rightNullable;

var result = sqlBinaryExpression.Update(left, right);

return result is SqlBinaryExpression sqlBinaryResult
Expand All @@ -876,14 +908,128 @@ SqlExpression AddNullConcatenationProtection(SqlExpression argument, RelationalT
? (SqlExpression)_sqlExpressionFactory.Constant(string.Empty, typeMapping)
: _sqlExpressionFactory.Coalesce(argument, _sqlExpressionFactory.Constant(string.Empty, typeMapping));
}

private bool IsCompareTo(CaseExpression caseExpression)
{
if (caseExpression != null
&& caseExpression.Operand == null
&& caseExpression.WhenClauses.Count == 3
&& caseExpression.WhenClauses.All(c => c.Test is SqlBinaryExpression
&& c.Result is SqlConstantExpression constant
&& constant.Value is int))
{
var whenClauses = caseExpression.WhenClauses.Select(c => new
{
test = (SqlBinaryExpression)c.Test,
resultValue = (int)((SqlConstantExpression)c.Result).Value
}).ToList();

if (whenClauses[0].test.Left.Equals(whenClauses[1].test.Left)
&& whenClauses[1].test.Left.Equals(whenClauses[2].test.Left)
&& whenClauses[0].test.Right.Equals(whenClauses[1].test.Right)
&& whenClauses[1].test.Right.Equals(whenClauses[2].test.Right)
&& whenClauses[0].test.OperatorType == ExpressionType.Equal
&& whenClauses[1].test.OperatorType == ExpressionType.GreaterThan
&& whenClauses[2].test.OperatorType == ExpressionType.LessThan
&& whenClauses[0].resultValue == 0
&& whenClauses[1].resultValue == 1
&& whenClauses[2].resultValue == -1)
{
return true;
}
}

return false;
}

private SqlExpression OptimizeCompareTo(
SqlBinaryExpression sqlBinaryExpression,
int intValue,
CaseExpression caseExpression)
{
var testLeft = ((SqlBinaryExpression)caseExpression.WhenClauses[0].Test).Left;
var testRight = ((SqlBinaryExpression)caseExpression.WhenClauses[0].Test).Right;
var operatorType = sqlBinaryExpression.Right is SqlConstantExpression
? sqlBinaryExpression.OperatorType
: sqlBinaryExpression.OperatorType switch
{
ExpressionType.GreaterThan => ExpressionType.LessThan,
ExpressionType.GreaterThanOrEqual => ExpressionType.LessThanOrEqual,
ExpressionType.LessThan => ExpressionType.GreaterThan,
ExpressionType.LessThanOrEqual => ExpressionType.GreaterThanOrEqual,
_ => sqlBinaryExpression.OperatorType
};

if (operatorType == ExpressionType.NotEqual)
{
// CompareTo(a, b) != 0 -> a != b
// CompareTo(a, b) != 1 -> a <= b
// CompareTo(a, b) != -1 -> a >= b
return intValue switch
{
0 => _sqlExpressionFactory.NotEqual(testLeft, testRight),
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight),
};
}
else if (operatorType == ExpressionType.GreaterThan)
{
// CompareTo(a, b) > 0 -> a > b
// CompareTo(a, b) > 1 -> false
// CompareTo(a, b) > -1 -> a >= b
return intValue switch
{
0 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
1 => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping),
_ => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight),
};
}
else if (operatorType == ExpressionType.GreaterThanOrEqual)
{
// CompareTo(a, b) >= 0 -> a >= b
// CompareTo(a, b) >= 1 -> a > b
// CompareTo(a, b) >= -1 -> true
return intValue switch
{
0 => _sqlExpressionFactory.GreaterThanOrEqual(testLeft, testRight),
1 => _sqlExpressionFactory.GreaterThan(testLeft, testRight),
_ => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping),
};
}
else if (operatorType == ExpressionType.LessThan)
{
// CompareTo(a, b) < 0 -> a < b
// CompareTo(a, b) < 1 -> a <= b
// CompareTo(a, b) < -1 -> false
return intValue switch
{
0 => _sqlExpressionFactory.LessThan(testLeft, testRight),
1 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
_ => _sqlExpressionFactory.Constant(false, sqlBinaryExpression.TypeMapping),
};
}
else
{
// operatorType == ExpressionType.LessThanOrEqual
// CompareTo(a, b) <= 0 -> a <= b
// CompareTo(a, b) <= 1 -> true
// CompareTo(a, b) <= -1 -> a < b
return intValue switch
{
0 => _sqlExpressionFactory.LessThanOrEqual(testLeft, testRight),
1 => _sqlExpressionFactory.Constant(true, sqlBinaryExpression.TypeMapping),
_ => _sqlExpressionFactory.LessThan(testLeft, testRight),
};
}
}

/// <summary>
/// Visits a <see cref="SqlConstantExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlConstantExpression"> A sql constant expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlConstant(
[NotNull] SqlConstantExpression sqlConstantExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand All @@ -893,14 +1039,14 @@ protected virtual SqlExpression VisitSqlConstant(

return sqlConstantExpression;
}

/// <summary>
/// Visits a <see cref="SqlFragmentExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlFragmentExpression"> A sql fragment expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlFragment(
[NotNull] SqlFragmentExpression sqlFragmentExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand All @@ -910,14 +1056,14 @@ protected virtual SqlExpression VisitSqlFragment(

return sqlFragmentExpression;
}

/// <summary>
/// Visits a <see cref="SqlFunctionExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlFunctionExpression"> A sql function expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlFunction(
[NotNull] SqlFunctionExpression sqlFunctionExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand Down Expand Up @@ -951,14 +1097,14 @@ protected virtual SqlExpression VisitSqlFunction(

return sqlFunctionExpression.Update(instance, arguments);
}

/// <summary>
/// Visits a <see cref="SqlParameterExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlParameterExpression"> A sql parameter expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlParameter(
[NotNull] SqlParameterExpression sqlParameterExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand All @@ -970,14 +1116,14 @@ protected virtual SqlExpression VisitSqlParameter(
? _sqlExpressionFactory.Constant(null, sqlParameterExpression.TypeMapping)
: (SqlExpression)sqlParameterExpression;
}

/// <summary>
/// Visits a <see cref="SqlUnaryExpression"/> and computes its nullability.
/// </summary>
/// <param name="sqlUnaryExpression"> A sql unary expression to visit. </param>
/// <param name="allowOptimizedExpansion"> A bool value indicating if optimized expansion which considers null value as false value is allowed. </param>
/// <param name="nullable"> A bool value indicating whether the sql expression is nullable. </param>
/// <returns> An optimized sql expression. </returns>

protected virtual SqlExpression VisitSqlUnary(
[NotNull] SqlUnaryExpression sqlUnaryExpression, bool allowOptimizedExpansion, out bool nullable)
{
Expand Down Expand Up @@ -1163,6 +1309,32 @@ private SqlExpression OptimizeComparison(
: _sqlExpressionFactory.Equal(left, right);
}

var sqlConstantExpression = left as SqlConstantExpression ?? right as SqlConstantExpression;
var caseExpression = left as CaseExpression ?? right as CaseExpression;

// generic CASE statement comparison optimization:
// (CASE
// WHEN condition1 THEN result1
// WHEN condition2 THEN result2
// WHEN ...
// WHEN conditionN THEN resultN) == result1 -> condition1
if (sqlBinaryExpression.OperatorType == ExpressionType.Equal
&& sqlConstantExpression != null
&& sqlConstantExpression.Value != null
&& caseExpression != null
&& caseExpression.Operand == null)
{
var matchingCaseBlock = caseExpression.WhenClauses.FirstOrDefault(wc => sqlConstantExpression.Equals(wc.Result));
if (matchingCaseBlock != null)
{
// we don't know if it's nullable since we don't store nullability of specific fragments
// so we must assume it's nullable
nullable = true;

return matchingCaseBlock.Test;
}
}

nullable = false;

return sqlBinaryExpression.Update(left, right);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,21 @@ public virtual Task Nullable_string_FirstOrDefault_compared_to_nullable_string_L
== e.NullableStringB.MaybeScalar(x => x.LastOrDefault())));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Null_semantics_applied_to_CompareTo_equality(bool async)
{
await AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA.CompareTo(e.NullableStringB) == 0),
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA == e.NullableStringB));

await AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA.CompareTo(e.NullableStringB) != 0),
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableStringA != e.NullableStringB));
}

private string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5248,20 +5248,11 @@ public override async Task Double_order_by_on_string_compare(bool async)
{
await base.Double_order_by_on_string_compare(async);

// issue #16092
AssertSql(
@"SELECT [w].[Id], [w].[AmmunitionType], [w].[IsAutomatic], [w].[Name], [w].[OwnerFullName], [w].[SynergyWithId]
FROM [Weapons] AS [w]
ORDER BY CASE
WHEN (CASE
WHEN [w].[Name] = N'Marcus'' Lancer' THEN 0
WHEN [w].[Name] > N'Marcus'' Lancer' THEN 1
WHEN [w].[Name] < N'Marcus'' Lancer' THEN -1
END = 0) AND CASE
WHEN [w].[Name] = N'Marcus'' Lancer' THEN 0
WHEN [w].[Name] > N'Marcus'' Lancer' THEN 1
WHEN [w].[Name] < N'Marcus'' Lancer' THEN -1
END IS NOT NULL THEN CAST(1 AS bit)
WHEN [w].[Name] = N'Marcus'' Lancer' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END, [w].[Id]");
}
Expand All @@ -5280,21 +5271,12 @@ public override async Task String_compare_with_null_conditional_argument(bool as
{
await base.String_compare_with_null_conditional_argument(async);

// issue #16092
AssertSql(
@"SELECT [w0].[Id], [w0].[AmmunitionType], [w0].[IsAutomatic], [w0].[Name], [w0].[OwnerFullName], [w0].[SynergyWithId]
FROM [Weapons] AS [w]
LEFT JOIN [Weapons] AS [w0] ON [w].[SynergyWithId] = [w0].[Id]
ORDER BY CASE
WHEN (CASE
WHEN [w0].[Name] = N'Marcus'' Lancer' THEN 0
WHEN [w0].[Name] > N'Marcus'' Lancer' THEN 1
WHEN [w0].[Name] < N'Marcus'' Lancer' THEN -1
END = 0) AND CASE
WHEN [w0].[Name] = N'Marcus'' Lancer' THEN 0
WHEN [w0].[Name] > N'Marcus'' Lancer' THEN 1
WHEN [w0].[Name] < N'Marcus'' Lancer' THEN -1
END IS NOT NULL THEN CAST(1 AS bit)
WHEN [w0].[Name] = N'Marcus'' Lancer' THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END");
}
Expand All @@ -5303,21 +5285,12 @@ public override async Task String_compare_with_null_conditional_argument2(bool a
{
await base.String_compare_with_null_conditional_argument2(async);

// issue #16092
AssertSql(
@"SELECT [w0].[Id], [w0].[AmmunitionType], [w0].[IsAutomatic], [w0].[Name], [w0].[OwnerFullName], [w0].[SynergyWithId]
FROM [Weapons] AS [w]
LEFT JOIN [Weapons] AS [w0] ON [w].[SynergyWithId] = [w0].[Id]
ORDER BY CASE
WHEN (CASE
WHEN N'Marcus'' Lancer' = [w0].[Name] THEN 0
WHEN N'Marcus'' Lancer' > [w0].[Name] THEN 1
WHEN N'Marcus'' Lancer' < [w0].[Name] THEN -1
END = 0) AND CASE
WHEN N'Marcus'' Lancer' = [w0].[Name] THEN 0
WHEN N'Marcus'' Lancer' > [w0].[Name] THEN 1
WHEN N'Marcus'' Lancer' < [w0].[Name] THEN -1
END IS NOT NULL THEN CAST(1 AS bit)
WHEN N'Marcus'' Lancer' = [w0].[Name] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END");
}
Expand Down
Loading

0 comments on commit 0f36c1d

Please sign in to comment.