Skip to content

Commit

Permalink
Enable decimal full predicate pushdown for ClickHouse connector
Browse files Browse the repository at this point in the history
  • Loading branch information
sylph-eu committed Apr 22, 2024
1 parent 88618fe commit 616f9cb
Show file tree
Hide file tree
Showing 4 changed files with 451 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.RemoteTableName;
import io.trino.plugin.jdbc.SliceWriteFunction;
import io.trino.plugin.jdbc.WriteFunction;
import io.trino.plugin.jdbc.WriteMapping;
import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint;
import io.trino.plugin.jdbc.aggregation.ImplementCount;
Expand Down Expand Up @@ -77,6 +78,7 @@
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.math.RoundingMode;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.sql.Connection;
Expand Down Expand Up @@ -662,24 +664,24 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
case Types.DOUBLE:
return Optional.of(doubleColumnMapping());

case Types.NUMERIC:
case Types.DECIMAL:
int decimalDigits = typeHandle.requiredDecimalDigits();
int precision = typeHandle.requiredColumnSize();

ColumnMapping decimalColumnMapping;
if (getDecimalRounding(session) == ALLOW_OVERFLOW && precision > Decimals.MAX_PRECISION) {
int scale = Math.min(decimalDigits, getDecimalDefaultScale(session));
decimalColumnMapping = decimalColumnMapping(createDecimalType(Decimals.MAX_PRECISION, scale), getDecimalRoundingMode(session));
return Optional.of(customDecimalColumnMapping(createDecimalType(Decimals.MAX_PRECISION, scale), getDecimalRoundingMode(session)));
}
else {
decimalColumnMapping = decimalColumnMapping(createDecimalType(precision, max(decimalDigits, 0)));

// At the moment ClickHouse doesn't support negative scales and scales
// larger than precision, handle it nevertheless.
precision = precision + max(-decimalDigits, 0); // Map decimal(p, -s) (negative scale) to decimal(p+s, 0).
if (precision > Decimals.MAX_PRECISION) {
break; // Ignore column
}
return Optional.of(ColumnMapping.mapping(
decimalColumnMapping.getType(),
decimalColumnMapping.getReadFunction(),
decimalColumnMapping.getWriteFunction(),
// TODO (https://github.com/trinodb/trino/issues/7100) fix, enable and test decimal pushdown
DISABLE_PUSHDOWN));

return Optional.of(customDecimalColumnMapping(createDecimalType(precision, max(decimalDigits, 0)), UNNECESSARY));

case Types.DATE:
return Optional.of(dateColumnMappingUsingLocalDate(getClickHouseServerVersion(session)));
Expand Down Expand Up @@ -894,6 +896,78 @@ private static LongWriteFunction shortTimestampWithTimeZoneWriteFunction()
};
}

private ColumnMapping customDecimalColumnMapping(DecimalType decimalType, RoundingMode roundingMode)
{
ColumnMapping nativeMapping = decimalColumnMapping(decimalType, roundingMode);
return ColumnMapping.mapping(
nativeMapping.getType(),
nativeMapping.getReadFunction(),
customDecimalWriteFunction(decimalType, nativeMapping.getWriteFunction()),
FULL_PUSHDOWN);
}

private WriteFunction customDecimalWriteFunction(DecimalType decimalType, WriteFunction writeFunction)
{
checkArgument(writeFunction instanceof LongWriteFunction || writeFunction instanceof ObjectWriteFunction,
"writeFunction must be LongWriteFunction or ObjectWriteFunction");
if (writeFunction instanceof LongWriteFunction) {
LongWriteFunction longWriteFunction = (LongWriteFunction) writeFunction;
return new LongWriteFunction() {
@Override
public String getBindExpression()
{
// Syntax to force ClickHouse use Decimal parsing of floats
return format("?::Decimal(%d, %d)", decimalType.getPrecision(), decimalType.getScale());
}

@Override
public void set(PreparedStatement statement, int index, long value)
throws SQLException
{
longWriteFunction.set(statement, index, value);
}

@Override
public void setNull(PreparedStatement statement, int index)
throws SQLException
{
longWriteFunction.setNull(statement, index);
}
};
}

ObjectWriteFunction objectWriteFunction = (ObjectWriteFunction) writeFunction;
return new ObjectWriteFunction() {
@Override
public String getBindExpression()
{
// Syntax to force ClickHouse use Decimal parsing of floats
return format("?::Decimal(%d, %d)", decimalType.getPrecision(), decimalType.getScale());
}

@Override
public Class<Int128> getJavaType()
{
return Int128.class;
}

@Override
@SuppressWarnings("unchecked")
public void set(PreparedStatement statement, int index, Object value)
throws SQLException
{
objectWriteFunction.set(statement, index, value);
}

@Override
public void setNull(PreparedStatement statement, int index)
throws SQLException
{
objectWriteFunction.setNull(statement, index);
}
};
}

private ColumnMapping ipAddressColumnMapping(String clickhouseType)
{
return ColumnMapping.sliceMapping(
Expand Down
Loading

0 comments on commit 616f9cb

Please sign in to comment.