Skip to content

Commit

Permalink
Improve LIKE pushdown for ClickHouse complex expression
Browse files Browse the repository at this point in the history
  • Loading branch information
ssheikin committed Sep 30, 2024
1 parent 44178ea commit e20b82f
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.expression.ConnectorExpressionRule.RewriteContext;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.clickhouse.expression.RewriteLike;
import io.trino.plugin.clickhouse.expression.RewriteStringComparison;
import io.trino.plugin.clickhouse.expression.RewriteStringIn;
import io.trino.plugin.jdbc.BaseJdbcClient;
Expand Down Expand Up @@ -231,6 +232,7 @@ public ClickHouseClient(
.addStandardRules(this::quoted)
.add(new RewriteStringComparison())
.add(new RewriteStringIn())
.add(new RewriteLike())
.map("$not(value: boolean)").to("NOT value")
.build();
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.clickhouse.expression;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.CharType;
import io.trino.spi.type.VarcharType;

import java.util.Optional;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.constant;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable;
import static io.trino.plugin.clickhouse.ClickHouseClient.supportsPushdown;
import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;

public class RewriteLike
implements ConnectorExpressionRule<Call, ParameterizedExpression>
{
private static final Capture<Variable> LIKE_VALUE = newCapture();
// TODO allow Variable as a LIKE_PATTERN: "SELECT * FROM t WHERE column_a LIKE column_b" is a valid query in ClickHouse
// only Constant is allowed as LIKE_PATTERN, because according to
// https://clickhouse.com/docs/en/sql-reference/functions/string-search-functions#like
// ClickHouse requires backslashes in strings to be quoted as well, so you would actually need to write \\%, \\_ and \\\\ to match against literal %, _ and \
// if "column_a LIKE column_b" is pushed down, it requires more thorough consideration how to process escaping.
private static final Capture<Constant> LIKE_PATTERN = newCapture();
private static final Pattern<Call> PATTERN = call()
.with(functionName().equalTo(LIKE_FUNCTION_NAME))
.with(type().equalTo(BOOLEAN))
.with(argumentCount().equalTo(2))
.with(argument(0).matching(variable()
.with(type().matching(type -> type instanceof CharType || type instanceof VarcharType))
.matching((Variable variable, RewriteContext<ParameterizedExpression> context) -> supportsPushdown(variable, context))
.capturedAs(LIKE_VALUE)))
.with(argument(1).matching(constant()
.with(type().matching(type -> type instanceof CharType || type instanceof VarcharType))
.capturedAs(LIKE_PATTERN)));

@Override
public Pattern<Call> getPattern()
{
return PATTERN;
}

@Override
public Optional<ParameterizedExpression> rewrite(Call expression, Captures captures, RewriteContext<ParameterizedExpression> context)
{
Optional<ParameterizedExpression> value = context.defaultRewrite(captures.get(LIKE_VALUE));
if (value.isEmpty()) {
return Optional.empty();
}
Optional<ParameterizedExpression> pattern = context.defaultRewrite(captures.get(LIKE_PATTERN));
if (pattern.isEmpty()) {
return Optional.empty();
}

// Capture<Constant> LIKE_PATTERN guarantees that value is a single varchar
QueryParameter patternParameter = pattern.get().parameters().getFirst();
Slice slice = (Slice) patternParameter.getValue().orElseThrow();
// ClickHouse requires backslashes in strings to be quoted as well, so you would actually need to write \\%, \\_ and \\\\ to match against literal %, _ and \
String patternValue = new String(slice.byteArray(), UTF_8);
if (patternValue.contains("\\")) {
// TODO escape `\` appropriately and pushdown: .replace("\\", "\\\\\\\\")
return Optional.empty();
}

return Optional.of(new ParameterizedExpression(
format("%s LIKE %s", value.get().expression(), pattern.get().expression()),
ImmutableList.<QueryParameter>builder()
.addAll(value.get().parameters())
.addAll(pattern.get().parameters())
.build()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
case SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE,
SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT,
SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION,
SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE,
SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY,
SUPPORTS_TOPN_PUSHDOWN,
SUPPORTS_TRUNCATE -> true;
Expand Down Expand Up @@ -959,10 +960,13 @@ a_nullable_string_alias Nullable(Text),
a_nullable_fixed_string Nullable(FixedString(1)),
a_lowcardinality_nullable_string LowCardinality(Nullable(String)),
a_lowcardinality_nullable_fixed_string LowCardinality(Nullable(FixedString(1))),
a_enum_1 Enum('hello', 'world', 'a', 'b', 'c'),
a_enum_2 Enum('hello', 'world', 'a', 'b', 'c'))
a_enum_1 Enum('hello', 'world', 'a', 'b', 'c', '%', '_'),
a_enum_2 Enum('hello', 'world', 'a', 'b', 'c', '%', '_'))
ENGINE=Log""",
List.of(
"(10, 10), (10, 10), 'z', '\\\\', '\\\\', '\\\\', '\\\\', '\\\\', '\\\\', '\\\\', '\\\\', 'hello', 'world'",
"(10, 10), (10, 10), 'z', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_'",
"(10, 10), (10, 10), 'z', '%', '%', '%', '%', '%', '%', '%', '%', '%', '%'",
"(10, 10), (10, 10), 'z', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a'",
"(10, 10), (10, 10), 'z', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b'",
"(10, 10), (10, 10), 'z', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c'"))) {
Expand Down Expand Up @@ -1018,9 +1022,70 @@ a_enum_2 Enum('hello', 'world', 'a', 'b', 'c'))
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string NOT IN ('a', 'b')" + withConnectorExpression)).isFullyPushedDown();
assertThat(query(smallDomainCompactionThreshold, "SELECT some_column FROM " + table.getName() + " WHERE a_string NOT IN ('a', 'b')")).isNotFullyPushedDown(FilterNode.class);
assertThat(query(smallDomainCompactionThreshold, "SELECT some_column FROM " + table.getName() + " WHERE a_string NOT IN ('a', 'b')" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);

assertLike(true, table, withConnectorExpression, convertToVarchar);
assertLike(false, table, withConnectorExpression, convertToVarchar);
}
}

private void assertLike(boolean isPositive, TestTable table, String withConnectorExpression, Session convertToVarchar)
{
String like = isPositive ? "LIKE" : "NOT LIKE";
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " NULL")).returnsEmptyResult();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " 'b'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " 'b'" + withConnectorExpression)).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " 'b%'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " 'b%'" + withConnectorExpression)).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%b'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%b'" + withConnectorExpression)).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%b%'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%b%'" + withConnectorExpression)).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 " + like + " '%b%'")).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_enum_1 " + like + " '%b%'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query(convertToVarchar, "SELECT some_column FROM " + table.getName() + " WHERE unsupported_1 " + like + " '%b%'")).isNotFullyPushedDown(FilterNode.class);
assertThat(query(convertToVarchar, "SELECT some_column FROM " + table.getName() + " WHERE unsupported_1 " + like + " '%b%'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " a_string_alias")).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " a_string_alias" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " a_enum_1")).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " a_enum_1" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query(convertToVarchar, "SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " unsupported_1")).isNotFullyPushedDown(FilterNode.class);
assertThat(query(convertToVarchar, "SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " unsupported_1" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
// metacharacters
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '_'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '_'" + withConnectorExpression)).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '__'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '__'" + withConnectorExpression)).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%'" + withConnectorExpression)).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%%'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%%'" + withConnectorExpression)).isFullyPushedDown();
// escape
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\b'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\b'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\_'")).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\_'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\__'")).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\__'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\%'")).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\%'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\%%'")).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\%%'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\\\'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\\\'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\\\\\'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\\\\\'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\\\\\\\'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\\\\\\\'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\\\' ESCAPE '\\'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\\\' ESCAPE '\\'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\%' ESCAPE '\\'")).isFullyPushedDown();
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '\\%' ESCAPE '\\'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%$_%' ESCAPE '$'")).isNotFullyPushedDown(FilterNode.class);
assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string " + like + " '%$_%' ESCAPE '$'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
}

@Test
@Override // Override because ClickHouse doesn't follow SQL standard syntax
public void testExecuteProcedure()
Expand Down

0 comments on commit e20b82f

Please sign in to comment.