Skip to content

Commit

Permalink
Fix LIKE with char
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi authored and losipiuk committed Jun 25, 2021
1 parent 53bc515 commit 522a48f
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -886,11 +886,10 @@ protected Type visitLikePredicate(LikePredicate node, StackableAstVisitorContext
}

Type patternType = process(node.getPattern(), context);
if (!(valueType instanceof VarcharType)) {
if (!(patternType instanceof VarcharType)) {
// TODO can pattern be of char type?
coerceType(context, node.getPattern(), VARCHAR, "Pattern for LIKE expression");
}
coerceType(context, node.getPattern(), patternType, "Pattern for LIKE expression");
if (node.getEscape().isPresent()) {
Expression escape = node.getEscape().get();
Type escapeType = process(escape, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.slice.SliceUtf8.countCodePoints;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH;
Expand All @@ -129,6 +130,7 @@
import static io.trino.spi.function.InvocationConvention.simpleConvention;
import static io.trino.spi.function.OperatorType.EQUAL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
import static io.trino.spi.type.Chars.trimTrailingSpaces;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.spi.type.TypeUtils.writeNativeValue;
import static io.trino.spi.type.VarcharType.createVarcharType;
Expand Down Expand Up @@ -1130,22 +1132,30 @@ protected Object visitLikePredicate(LikePredicate node, Object context)

// if pattern is a constant without % or _ replace with a comparison
if (pattern instanceof Slice && (escape == null || escape instanceof Slice) && !isLikePattern((Slice) pattern, Optional.ofNullable((Slice) escape))) {
Slice unescapedPattern = unescapeLiteralLikePattern((Slice) pattern, Optional.ofNullable((Slice) escape));
Type valueType = type(node.getValue());
Type patternType = createVarcharType(unescapedPattern.length());
Type superType = typeCoercion.getCommonSuperType(valueType, patternType)
.orElseThrow(() -> new IllegalArgumentException("Missing super type when optimizing " + node));
Expression valueExpression = toExpression(value, valueType);
if (!valueType.equals(superType)) {
valueExpression = new Cast(valueExpression, toSqlType(superType), false, typeCoercion.isTypeOnlyCoercion(valueType, superType));
}
Slice unescapedPattern = unescapeLiteralLikePattern((Slice) pattern, Optional.ofNullable((Slice) escape));
VarcharType patternType = createVarcharType(countCodePoints(unescapedPattern));

Expression valueExpression;
Expression patternExpression;
if (superType instanceof VarcharType) {
if (valueType instanceof CharType) {
if (((CharType) valueType).getLength() != patternType.getBoundedLength()) {
return false;
}
valueExpression = toExpression(value, valueType);
patternExpression = toExpression(trimTrailingSpaces(unescapedPattern), valueType);
}
else if (valueType instanceof VarcharType) {
Type superType = typeCoercion.getCommonSuperType(valueType, patternType)
.orElseThrow(() -> new IllegalArgumentException("Missing super type when optimizing " + node));
valueExpression = toExpression(value, valueType);
if (!valueType.equals(superType)) {
valueExpression = new Cast(valueExpression, toSqlType(superType), false, typeCoercion.isTypeOnlyCoercion(valueType, superType));
}
patternExpression = toExpression(unescapedPattern, superType);
}
else {
patternExpression = toExpression(unescapedPattern, patternType);
patternExpression = new Cast(patternExpression, toSqlType(superType), false, typeCoercion.isTypeOnlyCoercion(patternType, superType));
throw new IllegalStateException("Unsupported valueType for LIKE: " + valueType);
}
return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, valueExpression, patternExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.trino.spi.type.SqlTimestamp;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.FeaturesConfig;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;

Expand Down Expand Up @@ -88,7 +89,7 @@ public final void destroyTestFunctions()
functionAssertions = null;
}

protected void assertFunction(String projection, Type expectedType, Object expected)
protected void assertFunction(@Language("SQL") String projection, Type expectedType, Object expected)
{
functionAssertions.assertFunction(projection, expectedType, expected);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.CharType.createCharType;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.DecimalType.createDecimalType;
import static io.trino.spi.type.DoubleType.DOUBLE;
Expand Down Expand Up @@ -85,6 +86,7 @@

public class TestExpressionInterpreter
{
private static final int TEST_CHAR_TYPE_LENGTH = 17;
private static final int TEST_VARCHAR_TYPE_LENGTH = 17;
private static final TypeProvider SYMBOL_TYPES = TypeProvider.copyOf(ImmutableMap.<Symbol, Type>builder()
.put(new Symbol("bound_integer"), INTEGER)
Expand All @@ -105,6 +107,7 @@ public class TestExpressionInterpreter
.put(new Symbol("unbound_long"), BIGINT)
.put(new Symbol("unbound_long2"), BIGINT)
.put(new Symbol("unbound_long3"), BIGINT)
.put(new Symbol("unbound_char"), createCharType(TEST_CHAR_TYPE_LENGTH))
.put(new Symbol("unbound_string"), VARCHAR)
.put(new Symbol("unbound_double"), DOUBLE)
.put(new Symbol("unbound_boolean"), BOOLEAN)
Expand Down Expand Up @@ -1376,6 +1379,27 @@ public void testLike()
assertOptimizedEquals("'%' LIKE 'z%' ESCAPE 'z'", "true");
}

/**
 * Verifies constant folding of LIKE when the value is a literal cast to char(n).
 * Each row pairs an expression with the constant it is expected to optimize to;
 * the rows expect a char(4) value built from 'abc' to match only patterns that
 * account for a fourth character (a trailing space, '_' or '%').
 */
@Test
public void testLikeChar()
{
    // { expression, expected optimized form } -- evaluated in declaration order
    String[][] expectations = {
            // pattern without wildcards
            {"CAST('abc' AS char(3)) LIKE 'abc'", "true"},
            {"CAST('abc' AS char(4)) LIKE 'abc'", "false"},
            {"CAST('abc' AS char(4)) LIKE 'abc '", "true"},

            // leading wildcard followed by the whole value
            {"CAST('abc' AS char(3)) LIKE '%abc'", "true"},
            {"CAST('abc' AS char(4)) LIKE '%abc'", "false"},
            {"CAST('abc' AS char(4)) LIKE '%abc '", "true"},

            // leading wildcard followed by the last character only
            {"CAST('abc' AS char(4)) LIKE '%c'", "false"},
            {"CAST('abc' AS char(4)) LIKE '%c '", "true"},

            // wildcards interleaved with every character
            {"CAST('abc' AS char(3)) LIKE '%a%b%c'", "true"},
            {"CAST('abc' AS char(4)) LIKE '%a%b%c'", "false"},
            {"CAST('abc' AS char(4)) LIKE '%a%b%c '", "true"},
            {"CAST('abc' AS char(4)) LIKE '%a%b%c_'", "true"},
            {"CAST('abc' AS char(4)) LIKE '%a%b%c%'", "true"},
    };

    for (String[] expectation : expectations) {
        String expression = expectation[0];
        String optimized = expectation[1];
        assertOptimizedEquals(expression, optimized);
    }
}

@Test
public void testLikeOptimization()
{
Expand All @@ -1397,6 +1421,61 @@ public void testLikeOptimization()
assertOptimizedEquals("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string");
}

/**
 * Verifies how LIKE is optimized when the value is an unbound char(n) symbol.
 * <p>
 * The cases cover three outcomes: the predicate folds to {@code false} (literal
 * pattern whose length differs from the char length), it is rewritten to an
 * equality comparison (literal pattern of exactly the char length), or it is
 * left as a LIKE (pattern contains unescaped wildcards or is not constant).
 */
@Test
public void testLikeCharOptimization()
{
    // A char(TEST_CHAR_TYPE_LENGTH) value made only of padding. Built from the constant
    // rather than typed as a literal run of spaces so that the expected pattern below
    // always has exactly the same length as the padded char value it is compared to.
    String paddedEmpty = String.format("%" + TEST_CHAR_TYPE_LENGTH + "s", "");

    // constant literal pattern of length shorter than value length
    assertOptimizedEquals("unbound_char LIKE 'abc'", "false");
    assertOptimizedEquals("unbound_char LIKE 'abc' ESCAPE '#'", "false");
    assertOptimizedEquals("unbound_char LIKE 'ab#_' ESCAPE '#'", "false");
    assertOptimizedEquals("unbound_char LIKE 'ab#%' ESCAPE '#'", "false");
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'abc'", "false");
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'abc' ESCAPE '#'", "false");

    // constant non-literal pattern of length shorter than value length
    assertOptimizedEquals("unbound_char LIKE 'ab_'", "unbound_char LIKE 'ab_'");
    assertOptimizedEquals("unbound_char LIKE 'ab%'", "unbound_char LIKE 'ab%'");
    assertOptimizedEquals("unbound_char LIKE 'ab%' ESCAPE '#'", "unbound_char LIKE 'ab%' ESCAPE '#'");

    // constant literal pattern of length equal to value length
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'abcd'", "CAST(unbound_char AS char(4)) = CAST('abcd' AS char(4))");
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'abcd' ESCAPE '#'", "CAST(unbound_char AS char(4)) = CAST('abcd' AS char(4))");
    // multi-byte characters: the pattern has 4 code points
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'jaźń'", "CAST(unbound_char AS char(4)) = CAST('jaźń' AS char(4))");
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'ab#_' ESCAPE '#'", "false");
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'ab#%' ESCAPE '#'", "false");

    // constant non-literal pattern of length equal to value length
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'ab#%' ESCAPE '\\'", "CAST(unbound_char AS char(4)) LIKE 'ab#%' ESCAPE '\\'");

    // constant pattern of length longer than value length
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'abcde'", "false");
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE 'abcde' ESCAPE '#'", "false");
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE '#%a#%b#%c#%d#%' ESCAPE '#'", "false");

    // constant non-literal pattern of length longer than value length
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE '%a%b%c%d%'", "CAST(unbound_char AS char(4)) LIKE '%a%b%c%d%'");
    assertOptimizedEquals("CAST(unbound_char AS char(4)) LIKE '%a%b%c%d%' ESCAPE '#'", "CAST(unbound_char AS char(4)) LIKE '%a%b%c%d%' ESCAPE '#'");

    // without explicit CAST on value, constant pattern of equal length
    assertOptimizedEquals(
            "unbound_char LIKE CAST(CAST('abc' AS char( " + TEST_CHAR_TYPE_LENGTH + ")) AS varchar(" + TEST_CHAR_TYPE_LENGTH + "))",
            "unbound_char = CAST('abc' AS char(17))");
    assertOptimizedEquals(
            "unbound_char LIKE CAST(CAST('abc' AS char( " + TEST_CHAR_TYPE_LENGTH + ")) AS varchar)",
            "unbound_char = CAST('abc' AS char(17))");

    // all-padding pattern: the expected literal must be the full padded width,
    // a single space would be a pattern of a different length than the value
    assertOptimizedEquals(
            "unbound_char LIKE CAST(CAST('' AS char(" + TEST_CHAR_TYPE_LENGTH + ")) AS varchar(" + TEST_CHAR_TYPE_LENGTH + ")) ESCAPE '#'",
            "unbound_char LIKE CAST('" + paddedEmpty + "' AS varchar(17)) ESCAPE '#'");
    assertOptimizedEquals(
            "unbound_char LIKE CAST(CAST('' AS char(" + TEST_CHAR_TYPE_LENGTH + ")) AS varchar) ESCAPE '#'",
            "unbound_char LIKE CAST('" + paddedEmpty + "' AS varchar(17)) ESCAPE '#'");

    // bound and unbound (non-constant) patterns are left as LIKE
    assertOptimizedEquals("unbound_char LIKE bound_pattern", "unbound_char LIKE VARCHAR '%el%'");
    assertOptimizedEquals("unbound_char LIKE unbound_pattern", "unbound_char LIKE unbound_pattern");
    assertOptimizedEquals("unbound_char LIKE unbound_pattern ESCAPE unbound_string", "unbound_char LIKE unbound_pattern ESCAPE unbound_string");
}

@Test
public void testOptimizeInvalidLike()
{
Expand Down
43 changes: 41 additions & 2 deletions core/trino-main/src/test/java/io/trino/sql/TestLikeFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ public void testLikeBasic()

assertFunction("'foob' LIKE 'f%b__'", BOOLEAN, false);
assertFunction("'foob' LIKE 'f%b'", BOOLEAN, true);

// value with explicit type (formal type potentially longer than actual length)
assertFunction("CAST('foo' AS varchar(6)) LIKE 'foo '", BOOLEAN, false);
assertFunction("CAST('foo ' AS varchar(6)) LIKE 'foo '", BOOLEAN, true);
assertFunction("CAST('foo' AS varchar(6)) LIKE 'foo___'", BOOLEAN, false);
assertFunction("CAST('foo' AS varchar(6)) LIKE 'foo%'", BOOLEAN, true);

// value and pattern with explicit type (formal type potentially longer than actual length)
assertFunction("CAST('foo' AS varchar(6)) LIKE CAST('foo' AS varchar(6))", BOOLEAN, true);
assertFunction("CAST('foo' AS varchar(6)) LIKE CAST('foo ' AS varchar(3))", BOOLEAN, true); // pattern gets truncated
assertFunction("CAST('foo' AS varchar(6)) LIKE CAST('foo ' AS varchar(6))", BOOLEAN, false);
}

@Test
Expand All @@ -69,8 +80,36 @@ public void testLikeChar()
assertFalse(likeChar(7L, utf8Slice("foob"), regex));
assertFalse(likeChar(7L, offsetHeapSlice("foob"), regex));

assertFunction("cast('foob' as char(6)) LIKE 'f%b__'", BOOLEAN, true);
assertFunction("cast('foob' as char(7)) LIKE 'f%b__'", BOOLEAN, false);
// pattern shorter than value length
assertFunction("CAST('foo' AS char(6)) LIKE 'foo'", BOOLEAN, false);
assertFunction("CAST('foo' AS char(6)) LIKE 'foo '", BOOLEAN, false);
assertFunction("CAST('foo' AS char(6)) LIKE 'fo_'", BOOLEAN, false);
assertFunction("CAST('foo' AS char(6)) LIKE 'fo%'", BOOLEAN, true);
assertFunction("CAST('foo' AS char(6)) LIKE '%foo'", BOOLEAN, false);
assertFunction("CAST('foo' AS char(6)) LIKE '_oo'", BOOLEAN, false);
assertFunction("CAST('foob' AS char(6)) LIKE 'f%b__'", BOOLEAN, true);
assertFunction("CAST('foob' AS char(7)) LIKE 'f%b__'", BOOLEAN, false);

// pattern of length equal to value length
assertFunction("CAST('foo' AS char(3)) LIKE 'foo'", BOOLEAN, true);
assertFunction("CAST('jaźń' AS char(4)) LIKE 'jaźń'", BOOLEAN, true);
assertFunction("CAST('foo' AS char(3)) LIKE 'fob'", BOOLEAN, false);
assertFunction("CAST('foo' AS char(6)) LIKE 'foo '", BOOLEAN, true);
assertFunction("CAST('foo' AS char(6)) LIKE 'foo __'", BOOLEAN, true);
assertFunction("CAST('foo' AS char(6)) LIKE '%%%%%%'", BOOLEAN, true);

// pattern longer than value length
assertFunction("CAST('foo' AS char(3)) LIKE '%%foo'", BOOLEAN, true);
assertFunction("CAST('foo' AS char(3)) LIKE 'f#_#_' ESCAPE '#'", BOOLEAN, false);
assertFunction("CAST('f__' AS char(3)) LIKE 'f#_#_' ESCAPE '#'", BOOLEAN, true);
assertFunction("CAST('foo' AS char(6)) LIKE 'foo '", BOOLEAN, false);
assertFunction("CAST('foo' AS char(6)) LIKE 'foo __ '", BOOLEAN, false);
assertFunction("CAST('foo' AS char(6)) LIKE '_______'", BOOLEAN, false);
assertFunction("CAST('foo' AS char(6)) LIKE '%%%%%%%'", BOOLEAN, true);
assertFunction("CAST('foo' AS char(6)) LIKE 'foo %%%%%%%'", BOOLEAN, true);
assertFunction("CAST('foo' AS char(6)) LIKE 'foo %%%%%%% '", BOOLEAN, true);
assertFunction("CAST('foo' AS char(6)) LIKE 'foo %%%%%%% '", BOOLEAN, false);
assertFunction("CAST('foobar' AS char(6)) LIKE 'foobar%%%%%%%'", BOOLEAN, true);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2719,6 +2719,12 @@ public void testExpressions()
.hasErrorCode(TYPE_MISMATCH);
assertFails("SELECT 'a' LIKE 'b' ESCAPE 1 FROM t1")
.hasErrorCode(TYPE_MISMATCH);
assertFails("SELECT 'abc' LIKE CHAR 'abc' FROM t1")
.hasErrorCode(TYPE_MISMATCH)
.hasMessage("line 1:19: Pattern for LIKE expression must evaluate to a varchar (actual: char(3))");
assertFails("SELECT 'abc' LIKE 'abc' ESCAPE CHAR '#' FROM t1")
.hasErrorCode(TYPE_MISMATCH)
.hasMessage("line 1:32: Escape for LIKE expression must evaluate to a varchar (actual: char(1))");

// extract
assertFails("SELECT EXTRACt(DAY FROM 'a') FROM t1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ WHERE ("cr_call_center_sk" = "cc_call_center_sk")
AND ("cd_education_status" = 'Unknown'))
OR (("cd_marital_status" = 'W')
AND ("cd_education_status" = 'Advanced Degree')))
AND ("hd_buy_potential" LIKE 'Unknown')
AND ("hd_buy_potential" LIKE 'Unknown ')
AND ("ca_gmt_offset" = -7)
GROUP BY "cc_call_center_id", "cc_name", "cc_manager", "cd_marital_status", "cd_education_status"
ORDER BY "sum"("cr_net_loss") DESC
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ WHERE ("cr_call_center_sk" = "cc_call_center_sk")
AND ("cd_education_status" = 'Unknown'))
OR (("cd_marital_status" = 'W')
AND ("cd_education_status" = 'Advanced Degree')))
AND ("hd_buy_potential" LIKE 'Unknown')
AND ("hd_buy_potential" LIKE 'Unknown ')
AND ("ca_gmt_offset" = -7)
GROUP BY "cc_call_center_id", "cc_name", "cc_manager", "cd_marital_status", "cd_education_status"
ORDER BY "sum"("cr_net_loss") DESC

0 comments on commit 522a48f

Please sign in to comment.