Skip to content

Commit

Permalink
Fixes and improvements to StartsWith/EndsWith/Contains (#31482)
Browse files Browse the repository at this point in the history
Closes #30493
Closes #11881
Closes #26735
  • Loading branch information
roji authored Aug 16, 2023
1 parent 3cf064e commit a07a1bd
Show file tree
Hide file tree
Showing 69 changed files with 1,351 additions and 905 deletions.
1 change: 1 addition & 0 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ protected virtual void GenerateLike(LikeExpression likeExpression, bool negated)
}

_relationalCommandBuilder.Append(" LIKE ");

Visit(likeExpression.Pattern);

if (likeExpression.EscapeChar != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio
/// <inheritdoc />
protected override Expression VisitParameter(ParameterExpression parameterExpression)
=> parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) == true
? new SqlParameterExpression(parameterExpression, null)
? new SqlParameterExpression(parameterExpression.Name, parameterExpression.Type, null)
: throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print()));

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,30 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;
/// <summary>
/// An expression that represents a parameter in a SQL tree.
/// </summary>
/// <remarks>
/// This is a simple wrapper around a <see cref="ParameterExpression" /> in the SQL tree.
/// Instances of this type cannot be constructed by application or database provider code. If this is a problem for your
/// application or provider, then please file an issue at
/// <see href="https://github.com/dotnet/efcore">github.com/dotnet/efcore</see>.
/// </remarks>
public sealed class SqlParameterExpression : SqlExpression
{
private readonly ParameterExpression _parameterExpression;
private readonly string _name;

internal SqlParameterExpression(ParameterExpression parameterExpression, RelationalTypeMapping? typeMapping)
: base(parameterExpression.Type.UnwrapNullableType(), typeMapping)
/// <summary>
/// Creates a new instance of the <see cref="SqlParameterExpression" /> class.
/// </summary>
/// <param name="name">The parameter name.</param>
/// <param name="type">The <see cref="Type" /> of the expression.</param>
/// <param name="typeMapping">The <see cref="RelationalTypeMapping" /> associated with the expression.</param>
public SqlParameterExpression(string name, Type type, RelationalTypeMapping? typeMapping)
: this(name, type.UnwrapNullableType(), type.IsNullableType(), typeMapping)
{
Check.DebugAssert(parameterExpression.Name != null, "Parameter must have name.");
}

_parameterExpression = parameterExpression;
_name = parameterExpression.Name;
IsNullable = parameterExpression.Type.IsNullableType();
private SqlParameterExpression(string name, Type type, bool nullable, RelationalTypeMapping? typeMapping)
: base(type, typeMapping)
{
Name = name;
IsNullable = nullable;
}

/// <summary>
/// The name of the parameter.
/// </summary>
public string Name
=> _name;
public string Name { get; }

/// <summary>
/// The bool value indicating if this parameter can have null values.
Expand All @@ -44,15 +42,15 @@ public string Name
/// <param name="typeMapping">A relational type mapping to apply.</param>
/// <returns>A new expression which has supplied type mapping.</returns>
public SqlExpression ApplyTypeMapping(RelationalTypeMapping? typeMapping)
=> new SqlParameterExpression(_parameterExpression, typeMapping);
=> new SqlParameterExpression(Name, Type, IsNullable, typeMapping);

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> this;

/// <inheritdoc />
protected override void Print(ExpressionPrinter expressionPrinter)
=> expressionPrinter.Append("@" + _parameterExpression.Name);
=> expressionPrinter.Append("@" + Name);

/// <inheritdoc />
public override bool Equals(object? obj)
Expand Down
124 changes: 96 additions & 28 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ protected virtual TableExpressionBase Visit(TableExpressionBase tableExpressionB
var newTable = Visit(innerJoinExpression.Table);
var newJoinPredicate = ProcessJoinPredicate(innerJoinExpression.JoinPredicate);

return TryGetBoolConstantValue(newJoinPredicate) == true
return IsTrue(newJoinPredicate)
? new CrossJoinExpression(newTable)
: innerJoinExpression.Update(newTable, newJoinPredicate);
}
Expand Down Expand Up @@ -301,7 +301,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression)
var predicate = Visit(selectExpression.Predicate, allowOptimizedExpansion: true, out _);
changed |= predicate != selectExpression.Predicate;

if (TryGetBoolConstantValue(predicate) == true)
if (IsTrue(predicate))
{
predicate = null;
changed = true;
Expand Down Expand Up @@ -333,7 +333,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression)
var having = Visit(selectExpression.Having, allowOptimizedExpansion: true, out _);
changed |= having != selectExpression.Having;

if (TryGetBoolConstantValue(having) == true)
if (IsTrue(having))
{
having = null;
changed = true;
Expand Down Expand Up @@ -519,20 +519,17 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
var test = Visit(
whenClause.Test, allowOptimizedExpansion: testIsCondition, preserveColumnNullabilityInformation: true, out _);

if (TryGetBoolConstantValue(test) is bool testConstantBool)
if (IsTrue(test))
{
if (testConstantBool)
{
testEvaluatesToTrue = true;
}
else
{
// if test evaluates to 'false' we can remove the WhenClause
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
RestoreNullValueColumnsList(currentNullValueColumnsCount);
testEvaluatesToTrue = true;
}
else if (IsFalse(test))
{
// if test evaluates to 'false' we can remove the WhenClause
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
RestoreNullValueColumnsList(currentNullValueColumnsCount);

continue;
}
continue;
}

var newResult = Visit(whenClause.Result, out var resultNullable);
Expand Down Expand Up @@ -570,7 +567,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
// if there is only one When clause and it's test evaluates to 'true' AND there is no else block, simply return the result
return elseResult == null
&& whenClauses.Count == 1
&& TryGetBoolConstantValue(whenClauses[0].Test) == true
&& IsTrue(whenClauses[0].Test)
? whenClauses[0].Result
: caseExpression.Update(operand, whenClauses, elseResult);
}
Expand Down Expand Up @@ -635,7 +632,7 @@ protected virtual SqlExpression VisitExists(

// if subquery has predicate which evaluates to false, we can simply return false
// if the exists is negated we need to return true instead
return TryGetBoolConstantValue(subquery.Predicate) == false
return IsFalse(subquery.Predicate)
? _sqlExpressionFactory.Constant(false, existsExpression.TypeMapping)
: existsExpression.Update(subquery);
}
Expand All @@ -658,7 +655,7 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
var subquery = Visit(inExpression.Subquery);

// a IN (SELECT * FROM table WHERE false) => false
if (TryGetBoolConstantValue(subquery.Predicate) == false)
if (IsFalse(subquery.Predicate))
{
nullable = false;

Expand Down Expand Up @@ -967,9 +964,64 @@ protected virtual SqlExpression VisitLike(LikeExpression likeExpression, bool al
var pattern = Visit(likeExpression.Pattern, out var patternNullable);
var escapeChar = Visit(likeExpression.EscapeChar, out var escapeCharNullable);

nullable = matchNullable || patternNullable || escapeCharNullable;
SqlExpression result = likeExpression.Update(match, pattern, escapeChar);

if (UseRelationalNulls)
{
nullable = matchNullable || patternNullable || escapeCharNullable;

return result;
}

nullable = false;

// The null semantics behavior we implement for LIKE is that it only returns true when both sides are non-null and match; any other
// input returns false:
// foo LIKE f% -> true
// foo LIKE null -> false
// null LIKE f% -> false
// null LIKE null -> false

if (IsNull(match) || IsNull(pattern) || IsNull(escapeChar))
{
return _sqlExpressionFactory.Constant(false, likeExpression.TypeMapping);
}

// A constant match-all pattern (%) returns true for all cases, except where the match string is null:
// nullable_foo LIKE % -> foo IS NOT NULL
// non_nullable_foo LIKE % -> true
if (pattern is SqlConstantExpression { Value: "%" })
{
return matchNullable
? _sqlExpressionFactory.IsNotNull(match)
: _sqlExpressionFactory.Constant(true, likeExpression.TypeMapping);
}

return likeExpression.Update(match, pattern, escapeChar);
if (!allowOptimizedExpansion)
{
if (matchNullable)
{
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(match));
}

if (patternNullable)
{
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(pattern));
}

if (escapeChar is not null && escapeCharNullable)
{
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(escapeChar));
}
}

return result;

SqlExpression GenerateNotNullCheck(SqlExpression operand)
=> OptimizeNonNullableNotExpression(
_sqlExpressionFactory.Not(
ProcessNullNotNull(
_sqlExpressionFactory.IsNull(operand), operandNullable: true)));
}

/// <summary>
Expand Down Expand Up @@ -1395,8 +1447,28 @@ protected virtual SqlExpression VisitJsonScalar(
/// </summary>
protected virtual bool PreferExistsToComplexIn => false;

private static bool? TryGetBoolConstantValue(SqlExpression? expression)
=> expression is SqlConstantExpression { Value: bool boolValue } ? boolValue : null;
// Note that we can check parameter values for null since we cache by the parameter nullability; but we cannot do the same for bool.
private bool IsNull(SqlExpression? expression)
=> expression is SqlConstantExpression { Value: null }
|| expression is SqlParameterExpression { Name: string parameterName } && ParameterValues[parameterName] is null;

private bool IsTrue(SqlExpression? expression)
=> expression is SqlConstantExpression { Value: true };

private bool IsFalse(SqlExpression? expression)
=> expression is SqlConstantExpression { Value: false };

private bool TryGetBool(SqlExpression? expression, out bool value)
{
if (expression is SqlConstantExpression { Value: bool b })
{
value = b;
return true;
}

value = false;
return false;
}

private void RestoreNonNullableColumnsList(int counter)
{
Expand Down Expand Up @@ -1486,7 +1558,7 @@ private SqlExpression OptimizeComparison(
return result;
}

if (TryGetBoolConstantValue(right) is bool rightBoolValue
if (TryGetBool(right, out var rightBoolValue)
&& !leftNullable
&& left.TypeMapping!.Converter == null)
{
Expand All @@ -1502,7 +1574,7 @@ private SqlExpression OptimizeComparison(
: left;
}

if (TryGetBoolConstantValue(left) is bool leftBoolValue
if (TryGetBool(left, out var leftBoolValue)
&& !rightNullable
&& right.TypeMapping!.Converter == null)
{
Expand Down Expand Up @@ -2069,10 +2141,6 @@ private SqlExpression ProcessNullNotNull(SqlUnaryExpression sqlUnaryExpression,
private static bool IsLogicalNot(SqlUnaryExpression? sqlUnaryExpression)
=> sqlUnaryExpression is { OperatorType: ExpressionType.Not } && sqlUnaryExpression.Type == typeof(bool);

private bool IsNull(SqlExpression expression)
=> expression is SqlConstantExpression { Value: null }
|| expression is SqlParameterExpression { Name: string parameterName } && ParameterValues[parameterName] is null;

// ?a == ?b -> [(a == b) && (a != null && b != null)] || (a == null && b == null))
//
// a | b | F1 = a == b | F2 = (a != null && b != null) | F3 = F1 && F2 |
Expand Down
Loading

0 comments on commit a07a1bd

Please sign in to comment.