Skip to content

Commit

Permalink
Correctly escape escape character in String.Contains(), `String.Sta…
Browse files Browse the repository at this point in the history
…rtsWith()` and `String.EndsWith()` translations if pattern doesn't contain wildcards.
  • Loading branch information
lauxjpn committed Dec 8, 2024
1 parent fde6828 commit 8bbd78a
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 11 deletions.
10 changes: 6 additions & 4 deletions src/EFCore.MySql/Extensions/MySqlDbFunctionsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,9 @@ public static int DateDiffNanosecond(
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="matchExpression">The property of entity that is to be matched.</param>
/// <param name="pattern">The pattern which may involve wildcards %,_,[,],^.</param>
/// <param name="pattern">
/// The pattern which may involve the wildcards `%` and `_`. Use the character `\` to escape wildcards and itself.
/// </param>
/// <returns>true if there is a match.</returns>
public static bool Like<T>(
[CanBeNull] this DbFunctions _,
Expand All @@ -1261,10 +1263,10 @@ public static bool Like<T>(
/// </summary>
/// <param name="_">The DbFunctions instance.</param>
/// <param name="matchExpression">The property of entity that is to be matched.</param>
/// <param name="pattern">The pattern which may involve wildcards %,_,[,],^.</param>
/// <param name="pattern">The pattern which may involve the wildcards `%` and `_`.</param>
/// <param name="escapeCharacter">
/// The escape character (as a single character string) to use in front of %,_,[,],^
/// if they are not used as wildcards.
/// The escape character (as a single character string) to use in front of `%` and `_` (if they are not used as wildcards), and
/// itself.
/// </param>
/// <returns>true if there is a match.</returns>
public static bool Like<T>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ private SqlExpression MakeStartsWithEndsWithExpressionImpl(
targetTransform(target),
prefixSuffixTransform(
_sqlExpressionFactory.Constant(
$"{(startsWith ? string.Empty : "%")}{(s.Any(IsLikeWildChar) ? EscapeLikePattern(s) : s)}{(startsWith ? "%" : string.Empty)}"))),
$"{(startsWith ? string.Empty : "%")}{(s.Any(IsLikeWildOrEscapeChar) ? EscapeLikePattern(s) : s)}{(startsWith ? "%" : string.Empty)}"))),
_ => throw new UnreachableException(),
};
}
Expand Down Expand Up @@ -466,8 +466,6 @@ private SqlExpression MakeContainsExpressionImpl(

if (pattern is SqlConstantExpression constantPatternExpression)
{
// The prefix is constant. Aside from null or empty, we escape all special characters (%, _, \)
// in C# and send a simple LIKE.
// The prefix is constant. Aside from null or empty, we escape all special characters (%, _, \)
// in C# and send a simple LIKE.
return constantPatternExpression.Value switch
Expand All @@ -476,7 +474,7 @@ private SqlExpression MakeContainsExpressionImpl(
"" => _sqlExpressionFactory.Like(targetTransform(target), _sqlExpressionFactory.Constant("%")),
string s => _sqlExpressionFactory.Like(
targetTransform(target),
patternTransform(_sqlExpressionFactory.Constant($"%{(s.Any(IsLikeWildChar) ? EscapeLikePattern(s) : s)}%"))),
patternTransform(_sqlExpressionFactory.Constant($"%{(s.Any(IsLikeWildOrEscapeChar) ? EscapeLikePattern(s) : s)}%"))),
_ => throw new UnreachableException(),
};
}
Expand Down Expand Up @@ -697,15 +695,15 @@ private SqlExpression Locate(SqlExpression sub, SqlExpression str, SqlExpression

private const char LikeEscapeChar = '\\';

private static bool IsLikeWildChar(char c) => c == '%' || c == '_';
private static bool IsLikeWildOrEscapeChar(char c) => IsLikeWildChar(c) || LikeEscapeChar == c;
private static bool IsLikeWildChar(char c) => c is '%' or '_';

private static string EscapeLikePattern(string pattern)
{
var builder = new StringBuilder();
foreach (var c in pattern)
{
if (IsLikeWildChar(c) ||
c == LikeEscapeChar)
if (IsLikeWildOrEscapeChar(c))
{
builder.Append(LikeEscapeChar);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,5 +370,101 @@ public virtual void Radians()
WHERE `c`.`CustomerID` = 'VINET'
LIMIT 1");
}

[ConditionalFact]
public virtual void Contains_with_escape_char()
{
using var context = CreateContext();
var count = context.Customers.Count(c => c.CompanyName.Replace("/", @"\").Contains(@"\"));

Assert.Equal(1, count);

AssertSql(
"""
SELECT COUNT(*)
FROM `Customers` AS `c`
WHERE REPLACE(`c`.`CompanyName`, '/', '\\') LIKE '%\\\\%'
""");
}

[ConditionalFact]
public virtual void Contains_with_wild_char()
{
using var context = CreateContext();
var count = context.Customers.Count(c => c.CompanyName.Replace("/", "%").Contains("%"));

Assert.Equal(1, count);

AssertSql(
"""
SELECT COUNT(*)
FROM `Customers` AS `c`
WHERE REPLACE(`c`.`CompanyName`, '/', '%') LIKE '%\\%%'
""");
}

[ConditionalFact]
public virtual void StartsWith_with_escape_char()
{
using var context = CreateContext();
var count = context.Customers.Count(c => c.CompanyName.Replace("A", @"\").StartsWith(@"\"));

Assert.Equal(4, count);

AssertSql(
"""
SELECT COUNT(*)
FROM `Customers` AS `c`
WHERE REPLACE(`c`.`CompanyName`, 'A', '\\') LIKE '\\\\%'
""");
}

[ConditionalFact]
public virtual void StartsWith_with_wild_char()
{
using var context = CreateContext();
var count = context.Customers.Count(c => c.CompanyName.Replace("A", @"%").StartsWith(@"%"));

Assert.Equal(4, count);

AssertSql(
"""
SELECT COUNT(*)
FROM `Customers` AS `c`
WHERE REPLACE(`c`.`CompanyName`, 'A', '%') LIKE '\\%%'
""");
}

[ConditionalFact]
public virtual void EndsWith_with_escape_char()
{
using var context = CreateContext();
var count = context.Customers.Count(c => c.CompanyName.Replace("a", @"\").EndsWith(@"\"));

Assert.Equal(7, count);

AssertSql(
"""
SELECT COUNT(*)
FROM `Customers` AS `c`
WHERE REPLACE(`c`.`CompanyName`, 'a', '\\') LIKE '%\\\\'
""");
}

[ConditionalFact]
public virtual void EndsWith_with_wild_char()
{
using var context = CreateContext();
var count = context.Customers.Count(c => c.CompanyName.Replace("a", @"%").EndsWith(@"%"));

Assert.Equal(7, count);

AssertSql(
"""
SELECT COUNT(*)
FROM `Customers` AS `c`
WHERE REPLACE(`c`.`CompanyName`, 'a', '%') LIKE '%\\%'
""");
}
}
}

0 comments on commit 8bbd78a

Please sign in to comment.