Skip to content

Commit

Permalink
Improve DateTime64 pushdown to ClickHouse
Browse files Browse the repository at this point in the history
  • Loading branch information
ssheikin committed Oct 16, 2024
1 parent c4c68c1 commit 0c80cf5
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.expression.ConnectorExpressionRule.RewriteContext;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.clickhouse.expression.RewriteComparison;
import io.trino.plugin.clickhouse.expression.RewriteLike;
import io.trino.plugin.clickhouse.expression.RewriteStringComparison;
import io.trino.plugin.clickhouse.expression.RewriteStringIn;
import io.trino.plugin.clickhouse.expression.RewriteTimestampConstant;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ColumnMapping;
Expand Down Expand Up @@ -109,12 +110,16 @@
import java.util.Map.Entry;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.TimeZone;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import static com.clickhouse.data.ClickHouseDataType.DateTime;
import static com.clickhouse.data.ClickHouseDataType.DateTime64;
import static com.clickhouse.data.ClickHouseDataType.FixedString;
import static com.clickhouse.data.ClickHouseValues.convertToQuotedString;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.emptyToNull;
Expand Down Expand Up @@ -242,7 +247,9 @@ public ClickHouseClient(
JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
.add(new RewriteStringComparison())
.add(new RewriteTimestampConstant())
.add(new RewriteComparison(ImmutableList.of(CharType.class, VarcharType.class), ImmutableSet.of(FixedString, ClickHouseDataType.String)))
.add(new RewriteComparison(ImmutableList.of(TimestampType.class, TimestampWithTimeZoneType.class), ImmutableSet.of(DateTime, DateTime64)))
.add(new RewriteStringIn())
.add(new RewriteLike())
.map("$not(value: boolean)").to("NOT value")
Expand Down Expand Up @@ -1074,17 +1081,14 @@ private static SliceWriteFunction uuidWriteFunction()
return (statement, index, value) -> statement.setObject(index, trinoUuidToJavaUuid(value), Types.OTHER);
}

public static boolean supportsPushdown(Variable variable, RewriteContext<ParameterizedExpression> context)
public static boolean supportsPushdown(Variable variable, RewriteContext<ParameterizedExpression> context, Set<ClickHouseDataType> nativeTypes)
{
JdbcTypeHandle typeHandle = ((JdbcColumnHandle) context.getAssignment(variable.getName()))
.getJdbcTypeHandle();
String jdbcTypeName = typeHandle.jdbcTypeName()
.orElseThrow(() -> new TrinoException(JDBC_ERROR, "Type name is missing: " + typeHandle));
ClickHouseColumn column = ClickHouseColumn.of("", jdbcTypeName);
ClickHouseDataType columnDataType = column.getDataType();
return switch (columnDataType) {
case FixedString, String -> true;
default -> false;
};
return nativeTypes.contains(columnDataType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
*/
package io.trino.plugin.clickhouse.expression;

import com.clickhouse.data.ClickHouseDataType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
Expand All @@ -24,9 +26,8 @@
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.CharType;
import io.trino.spi.type.VarcharType;

import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

Expand All @@ -43,26 +44,34 @@
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.lang.String.format;

public class RewriteStringComparison
public class RewriteComparison
implements ConnectorExpressionRule<Call, ParameterizedExpression>
{
private static final Capture<ConnectorExpression> LEFT = newCapture();
private static final Capture<ConnectorExpression> RIGHT = newCapture();
private static final Pattern<Call> PATTERN = call()
.with(type().equalTo(BOOLEAN))
.with(functionName().matching(Stream.of(ComparisonOperator.values())
.filter(comparison -> comparison != ComparisonOperator.IDENTICAL)
.map(ComparisonOperator::getFunctionName)
.collect(toImmutableSet())
::contains))
.with(argumentCount().equalTo(2))
.with(argument(0).matching(expression().with(type().matching(type -> type instanceof CharType || type instanceof VarcharType)).capturedAs(LEFT)))
.with(argument(1).matching(expression().with(type().matching(type -> type instanceof CharType || type instanceof VarcharType)).capturedAs(RIGHT)));

private final Pattern<Call> pattern;
private final ImmutableSet<ClickHouseDataType> nativeTypes;

public RewriteComparison(List<Class<?>> classes, ImmutableSet<ClickHouseDataType> nativeTypes)
{
pattern = call()
.with(type().equalTo(BOOLEAN))
.with(functionName().matching(Stream.of(ComparisonOperator.values())
.filter(comparison -> comparison != ComparisonOperator.IDENTICAL)
.map(ComparisonOperator::getFunctionName)
.collect(toImmutableSet())
::contains))
.with(argumentCount().equalTo(2))
.with(argument(0).matching(expression().with(type().matching(type -> classes.stream().anyMatch(aClass -> aClass.isInstance(type)))).capturedAs(LEFT)))
.with(argument(1).matching(expression().with(type().matching(type -> classes.stream().anyMatch(aClass -> aClass.isInstance(type)))).capturedAs(RIGHT)));
this.nativeTypes = nativeTypes;
}

@Override
public Pattern<Call> getPattern()
{
return PATTERN;
return pattern;
}

@Override
Expand All @@ -72,11 +81,11 @@ public Optional<ParameterizedExpression> rewrite(Call expression, Captures captu
ConnectorExpression leftExpression = captures.get(LEFT);
ConnectorExpression rightExpression = captures.get(RIGHT);

if (leftExpression instanceof Variable variable && !supportsPushdown(variable, context)) {
if (leftExpression instanceof Variable variable && !supportsPushdown(variable, context, nativeTypes)) {
return Optional.empty();
}

if (rightExpression instanceof Variable variable && !supportsPushdown(variable, context)) {
if (rightExpression instanceof Variable variable && !supportsPushdown(variable, context, nativeTypes)) {
return Optional.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
*/
package io.trino.plugin.clickhouse.expression;

import com.clickhouse.data.ClickHouseDataType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
Expand All @@ -29,6 +31,7 @@

import java.util.Optional;

import static com.clickhouse.data.ClickHouseDataType.FixedString;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
Expand Down Expand Up @@ -59,7 +62,7 @@ public class RewriteLike
.with(argumentCount().equalTo(2))
.with(argument(0).matching(variable()
.with(type().matching(type -> type instanceof CharType || type instanceof VarcharType))
.matching((Variable variable, RewriteContext<ParameterizedExpression> context) -> supportsPushdown(variable, context))
.matching((Variable variable, RewriteContext<ParameterizedExpression> context) -> supportsPushdown(variable, context, ImmutableSet.of(FixedString, ClickHouseDataType.String)))
.capturedAs(LIKE_VALUE)))
.with(argument(1).matching(constant()
.with(type().matching(type -> type instanceof CharType || type instanceof VarcharType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
*/
package io.trino.plugin.clickhouse.expression;

import com.clickhouse.data.ClickHouseDataType;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
Expand All @@ -31,6 +33,7 @@
import java.util.List;
import java.util.Optional;

import static com.clickhouse.data.ClickHouseDataType.FixedString;
import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
Expand Down Expand Up @@ -59,7 +62,7 @@ public class RewriteStringIn
.with(argumentCount().equalTo(2))
.with(argument(0).matching(variable()
.with(type().matching(type -> type instanceof CharType || type instanceof VarcharType))
.matching((Variable variable, RewriteContext<ParameterizedExpression> context) -> supportsPushdown(variable, context))
.matching((Variable variable, RewriteContext<ParameterizedExpression> context) -> supportsPushdown(variable, context, ImmutableSet.of(FixedString, ClickHouseDataType.String)))
.capturedAs(VALUE)))
.with(argument(1).matching(call()
.with(functionName().equalTo(ARRAY_CONSTRUCTOR_FUNCTION_NAME))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.clickhouse.expression;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;

import java.util.Optional;

import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.constant;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;

public class RewriteTimestampConstant
implements ConnectorExpressionRule<Constant, ParameterizedExpression>
{
private static final Pattern<Constant> PATTERN = constant().with(type().matching(type -> type instanceof TimestampType || type instanceof TimestampWithTimeZoneType));

@Override
public Pattern<Constant> getPattern()
{
return PATTERN;
}

@Override
public Optional<ParameterizedExpression> rewrite(Constant constant, Captures captures, RewriteContext<ParameterizedExpression> context)
{
Object value = constant.getValue();
if (value == null) {
// TODO we could handle NULL values too
return Optional.empty();
}
return Optional.of(new ParameterizedExpression("?", ImmutableList.of(new QueryParameter(constant.getType(), Optional.of(value)))));
}
}
Loading

0 comments on commit 0c80cf5

Please sign in to comment.