From e20b82ffbfd0e6803c21977779a748413c65e1e3 Mon Sep 17 00:00:00 2001 From: Sasha Sheikin Date: Fri, 27 Sep 2024 11:28:01 +0200 Subject: [PATCH] Improve LIKE pushdown for ClickHouse complex expression --- .../plugin/clickhouse/ClickHouseClient.java | 2 + .../clickhouse/expression/RewriteLike.java | 103 ++++++++++++++++++ .../TestClickHouseConnectorTest.java | 69 +++++++++++- 3 files changed, 172 insertions(+), 2 deletions(-) create mode 100644 plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/expression/RewriteLike.java diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index 4b45c77acb1a7e..662d28ed0669b3 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -30,6 +30,7 @@ import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.base.expression.ConnectorExpressionRule.RewriteContext; import io.trino.plugin.base.mapping.IdentifierMapping; +import io.trino.plugin.clickhouse.expression.RewriteLike; import io.trino.plugin.clickhouse.expression.RewriteStringComparison; import io.trino.plugin.clickhouse.expression.RewriteStringIn; import io.trino.plugin.jdbc.BaseJdbcClient; @@ -231,6 +232,7 @@ public ClickHouseClient( .addStandardRules(this::quoted) .add(new RewriteStringComparison()) .add(new RewriteStringIn()) + .add(new RewriteLike()) .map("$not(value: boolean)").to("NOT value") .build(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/expression/RewriteLike.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/expression/RewriteLike.java new file mode 100644 index 00000000000000..f79f693b0a9714 --- /dev/null +++ 
b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/expression/RewriteLike.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.clickhouse.expression;
+
+import com.google.common.collect.ImmutableList;
+import io.airlift.slice.Slice;
+import io.trino.matching.Capture;
+import io.trino.matching.Captures;
+import io.trino.matching.Pattern;
+import io.trino.plugin.base.expression.ConnectorExpressionRule;
+import io.trino.plugin.jdbc.QueryParameter;
+import io.trino.plugin.jdbc.expression.ParameterizedExpression;
+import io.trino.spi.expression.Call;
+import io.trino.spi.expression.Constant;
+import io.trino.spi.expression.Variable;
+import io.trino.spi.type.CharType;
+import io.trino.spi.type.VarcharType;
+
+import java.util.Optional;
+
+import static io.trino.matching.Capture.newCapture;
+import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
+import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
+import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
+import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.constant;
+import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
+import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;
+import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable;
+import static io.trino.plugin.clickhouse.ClickHouseClient.supportsPushdown;
+import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME;
+import static io.trino.spi.type.BooleanType.BOOLEAN;
+import static java.lang.String.format;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+/**
+ * Connector expression rule that rewrites a two-argument {@code LIKE} call into a ClickHouse {@code LIKE} predicate.
+ * It matches when the value is a char/varchar variable accepted by {@code supportsPushdown} and the pattern is a
+ * char/varchar constant; patterns containing a backslash are rejected by {@link #rewrite}.
+ */
+public class RewriteLike
+        implements ConnectorExpressionRule<Call, ParameterizedExpression>
+{
+    private static final Capture<Variable> LIKE_VALUE = newCapture();
+    // TODO allow Variable as a LIKE_PATTERN: "SELECT * FROM t WHERE column_a LIKE column_b" is a valid query in ClickHouse
+    // only Constant is allowed as LIKE_PATTERN, because according to
+    // https://clickhouse.com/docs/en/sql-reference/functions/string-search-functions#like
+    // ClickHouse requires backslashes in strings to be quoted as well, so you would actually need to write \\%, \\_ and \\\\ to match against literal %, _ and \
+    // if "column_a LIKE column_b" is pushed down, it requires more thorough consideration how to process escaping.
+    private static final Capture<Constant> LIKE_PATTERN = newCapture();
+    private static final Pattern<Call> PATTERN = call()
+            .with(functionName().equalTo(LIKE_FUNCTION_NAME))
+            .with(type().equalTo(BOOLEAN))
+            .with(argumentCount().equalTo(2))
+            .with(argument(0).matching(variable()
+                    .with(type().matching(type -> type instanceof CharType || type instanceof VarcharType))
+                    .matching((Variable variable, RewriteContext<?> context) -> supportsPushdown(variable, context))
+                    .capturedAs(LIKE_VALUE)))
+            .with(argument(1).matching(constant()
+                    .with(type().matching(type -> type instanceof CharType || type instanceof VarcharType))
+                    .capturedAs(LIKE_PATTERN)));
+
+    @Override
+    public Pattern<Call> getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Builds {@code <value> LIKE <pattern>} from the default rewrites of both captured arguments.
+     *
+     * @return the rewritten expression, or empty when either argument cannot be rewritten
+     *         or the pattern contains a backslash
+     */
+    @Override
+    public Optional<ParameterizedExpression> rewrite(Call expression, Captures captures, RewriteContext<ParameterizedExpression> context)
+    {
+        Optional<ParameterizedExpression> value = context.defaultRewrite(captures.get(LIKE_VALUE));
+        if (value.isEmpty()) {
+            return Optional.empty();
+        }
+        Optional<ParameterizedExpression> pattern = context.defaultRewrite(captures.get(LIKE_PATTERN));
+        if (pattern.isEmpty()) {
+            return Optional.empty();
+        }
+
+        // Capture LIKE_PATTERN guarantees that value is a single varchar
+        QueryParameter patternParameter = pattern.get().parameters().getFirst();
+        Slice slice = (Slice) patternParameter.getValue().orElseThrow();
+        // ClickHouse requires backslashes in strings to be quoted as well, so you would actually need to write \\%, \\_ and \\\\ to match against literal %, _ and \
+        // Decode only the bytes that belong to the slice: byteArray() exposes the whole backing array,
+        // which is not guaranteed to start at index 0 or to end where the slice ends.
+        String patternValue = new String(slice.byteArray(), slice.byteArrayOffset(), slice.length(), UTF_8);
+        if (patternValue.contains("\\")) {
+            // TODO escape `\` appropriately and pushdown: .replace("\\", "\\\\\\\\")
+            return Optional.empty();
+        }
+
+        return Optional.of(new ParameterizedExpression(
+                format("%s LIKE %s", value.get().expression(), pattern.get().expression()),
+                ImmutableList.<QueryParameter>builder()
+                        .addAll(value.get().parameters())
+                        .addAll(pattern.get().parameters())
+                        .build()));
+    }
+}
diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java
index 5ec5838422bed4..49572b1fabf98c 100644
--- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java
+++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java
@@ -66,6 +66,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
             case SUPPORTS_AGGREGATION_PUSHDOWN_COVARIANCE,
                  SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT,
                  SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION,
+                 SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE,
                  SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY,
                  SUPPORTS_TOPN_PUSHDOWN,
                  SUPPORTS_TRUNCATE -> true;
@@ -959,10 +960,13 @@ a_nullable_string_alias Nullable(Text),
                 a_nullable_fixed_string Nullable(FixedString(1)),
                 a_lowcardinality_nullable_string LowCardinality(Nullable(String)),
                 a_lowcardinality_nullable_fixed_string LowCardinality(Nullable(FixedString(1))),
-                a_enum_1 
Enum('hello', 'world', 'a', 'b', 'c'),
-                a_enum_2 Enum('hello', 'world', 'a', 'b', 'c'))
+                a_enum_1 Enum('hello', 'world', 'a', 'b', 'c', '%', '_'),
+                a_enum_2 Enum('hello', 'world', 'a', 'b', 'c', '%', '_'))
                 ENGINE=Log""",
                 List.of(
+                        "(10, 10), (10, 10), 'z', '\\\\', '\\\\', '\\\\', '\\\\', '\\\\', '\\\\', '\\\\', '\\\\', 'hello', 'world'",
+                        "(10, 10), (10, 10), 'z', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_'",
+                        "(10, 10), (10, 10), 'z', '%', '%', '%', '%', '%', '%', '%', '%', '%', '%'",
                         "(10, 10), (10, 10), 'z', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a'",
                         "(10, 10), (10, 10), 'z', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b'",
                         "(10, 10), (10, 10), 'z', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c'"))) {
@@ -1018,9 +1022,60 @@ a_enum_2 Enum('hello', 'world', 'a', 'b', 'c'))
             assertThat(query("SELECT some_column FROM " + table.getName() + " WHERE a_string NOT IN ('a', 'b')" + withConnectorExpression)).isFullyPushedDown();
             assertThat(query(smallDomainCompactionThreshold, "SELECT some_column FROM " + table.getName() + " WHERE a_string NOT IN ('a', 'b')")).isNotFullyPushedDown(FilterNode.class);
             assertThat(query(smallDomainCompactionThreshold, "SELECT some_column FROM " + table.getName() + " WHERE a_string NOT IN ('a', 'b')" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+
+            assertLike(true, table, withConnectorExpression, convertToVarchar);
+            assertLike(false, table, withConnectorExpression, convertToVarchar);
         }
     }
 
+    /**
+     * Asserts the expected pushdown outcome of {@code LIKE} ({@code isPositive = true}) or {@code NOT LIKE} predicates
+     * against {@code table}. Each predicate is issued as written and again with {@code withConnectorExpression}
+     * appended; the queries that touch {@code unsupported_1} run with the {@code convertToVarchar} session.
+     */
+    private void assertLike(boolean isPositive, TestTable table, String withConnectorExpression, Session convertToVarchar)
+    {
+        String like = isPositive ? "LIKE" : "NOT LIKE";
+        String where = "SELECT some_column FROM " + table.getName() + " WHERE ";
+        String stringLike = where + "a_string " + like + " ";
+
+        assertThat(query(stringLike + "NULL")).returnsEmptyResult();
+        for (String pattern : List.of("'b'", "'b%'", "'%b'", "'%b%'")) {
+            assertThat(query(stringLike + pattern)).isFullyPushedDown();
+            assertThat(query(stringLike + pattern + withConnectorExpression)).isFullyPushedDown();
+        }
+        // a_enum_1 or unsupported_1 as the LIKE value: a FilterNode is expected to remain
+        assertThat(query(where + "a_enum_1 " + like + " '%b%'")).isNotFullyPushedDown(FilterNode.class);
+        assertThat(query(where + "a_enum_1 " + like + " '%b%'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+        assertThat(query(convertToVarchar, where + "unsupported_1 " + like + " '%b%'")).isNotFullyPushedDown(FilterNode.class);
+        assertThat(query(convertToVarchar, where + "unsupported_1 " + like + " '%b%'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+        // a column as the LIKE pattern: a FilterNode is expected to remain
+        for (String patternColumn : List.of("a_string_alias", "a_enum_1")) {
+            assertThat(query(stringLike + patternColumn)).isNotFullyPushedDown(FilterNode.class);
+            assertThat(query(stringLike + patternColumn + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+        }
+        assertThat(query(convertToVarchar, stringLike + "unsupported_1")).isNotFullyPushedDown(FilterNode.class);
+        assertThat(query(convertToVarchar, stringLike + "unsupported_1" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+        // metacharacters
+        for (String pattern : List.of("'_'", "'__'", "'%'", "'%%'")) {
+            assertThat(query(stringLike + pattern)).isFullyPushedDown();
+            assertThat(query(stringLike + pattern + withConnectorExpression)).isFullyPushedDown();
+        }
+        // escape
+        assertThat(query(stringLike + "'\\b'")).isFullyPushedDown();
+        assertThat(query(stringLike + "'\\b'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+        for (String pattern : List.of("'\\_'", "'\\__'", "'\\%'", "'\\%%'")) {
+            assertThat(query(stringLike + pattern)).isNotFullyPushedDown(FilterNode.class);
+            assertThat(query(stringLike + pattern + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+        }
+        for (String pattern : List.of("'\\'", "'\\\\'", "'\\\\\\'", "'\\\\\\\\'", "'\\\\' ESCAPE '\\'", "'\\%' ESCAPE '\\'")) {
+            assertThat(query(stringLike + pattern)).isFullyPushedDown();
+            assertThat(query(stringLike + pattern + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+        }
+        assertThat(query(stringLike + "'%$_%' ESCAPE '$'")).isNotFullyPushedDown(FilterNode.class);
+        assertThat(query(stringLike + "'%$_%' ESCAPE '$'" + withConnectorExpression)).isNotFullyPushedDown(FilterNode.class);
+    }
+
     @Test
     @Override // Override because ClickHouse doesn't follow SQL standard syntax
     public void testExecuteProcedure()