diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index 56177f98b94fc2..f323ac050e225a 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -44,6 +44,7 @@ import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.SliceWriteFunction; +import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint; import io.trino.plugin.jdbc.aggregation.ImplementCount; @@ -77,6 +78,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.math.MathContext; +import java.math.RoundingMode; import java.net.InetAddress; import java.net.UnknownHostException; import java.sql.Connection; @@ -662,24 +664,24 @@ public Optional toColumnMapping(ConnectorSession session, Connect case Types.DOUBLE: return Optional.of(doubleColumnMapping()); + case Types.NUMERIC: case Types.DECIMAL: int decimalDigits = typeHandle.requiredDecimalDigits(); int precision = typeHandle.requiredColumnSize(); - ColumnMapping decimalColumnMapping; if (getDecimalRounding(session) == ALLOW_OVERFLOW && precision > Decimals.MAX_PRECISION) { int scale = Math.min(decimalDigits, getDecimalDefaultScale(session)); - decimalColumnMapping = decimalColumnMapping(createDecimalType(Decimals.MAX_PRECISION, scale), getDecimalRoundingMode(session)); + return Optional.of(customDecimalColumnMapping(createDecimalType(Decimals.MAX_PRECISION, scale), getDecimalRoundingMode(session))); } - else { - decimalColumnMapping = decimalColumnMapping(createDecimalType(precision, max(decimalDigits, 0))); + + // At the moment ClickHouse doesn't support negative scales and scales + // larger than precision, handle it nevertheless. + precision = precision + max(-decimalDigits, 0); // Map decimal(p, -s) (negative scale) to decimal(p+s, 0). + if (precision > Decimals.MAX_PRECISION) { + break; // Ignore column } - return Optional.of(ColumnMapping.mapping( - decimalColumnMapping.getType(), - decimalColumnMapping.getReadFunction(), - decimalColumnMapping.getWriteFunction(), - // TODO (https://github.com/trinodb/trino/issues/7100) fix, enable and test decimal pushdown - DISABLE_PUSHDOWN)); + + return Optional.of(customDecimalColumnMapping(createDecimalType(precision, max(decimalDigits, 0)), UNNECESSARY)); case Types.DATE: return Optional.of(dateColumnMappingUsingLocalDate(getClickHouseServerVersion(session))); @@ -894,6 +896,78 @@ private static LongWriteFunction shortTimestampWithTimeZoneWriteFunction() }; } + private ColumnMapping customDecimalColumnMapping(DecimalType decimalType, RoundingMode roundingMode) + { + ColumnMapping nativeMapping = decimalColumnMapping(decimalType, roundingMode); + return ColumnMapping.mapping( + nativeMapping.getType(), + nativeMapping.getReadFunction(), + customDecimalWriteFunction(decimalType, nativeMapping.getWriteFunction()), + FULL_PUSHDOWN); + } + + private WriteFunction customDecimalWriteFunction(DecimalType decimalType, WriteFunction writeFunction) + { + checkArgument(writeFunction instanceof LongWriteFunction || writeFunction instanceof ObjectWriteFunction, + "writeFunction must be LongWriteFunction or ObjectWriteFunction"); + if (writeFunction instanceof LongWriteFunction) { + LongWriteFunction longWriteFunction = (LongWriteFunction) writeFunction; + return new LongWriteFunction() { + @Override + public String getBindExpression() + { + // Syntax to force ClickHouse use Decimal parsing of floats + return format("?::Decimal(%d, %d)", decimalType.getPrecision(), decimalType.getScale()); + } + + @Override + public void set(PreparedStatement statement, int index, long value) + throws SQLException + { + longWriteFunction.set(statement, index, value); + } + + @Override + public void setNull(PreparedStatement statement, int index) + throws SQLException + { + longWriteFunction.setNull(statement, index); + } + }; + } + + ObjectWriteFunction objectWriteFunction = (ObjectWriteFunction) writeFunction; + return new ObjectWriteFunction() { + @Override + public String getBindExpression() + { + // Syntax to force ClickHouse use Decimal parsing of floats + return format("?::Decimal(%d, %d)", decimalType.getPrecision(), decimalType.getScale()); + } + + @Override + public Class getJavaType() + { + return Int128.class; + } + + @Override + @SuppressWarnings("unchecked") + public void set(PreparedStatement statement, int index, Object value) + throws SQLException + { + objectWriteFunction.set(statement, index, value); + } + + @Override + public void setNull(PreparedStatement statement, int index) + throws SQLException + { + objectWriteFunction.setNull(statement, index); + } + }; + } + private ColumnMapping ipAddressColumnMapping(String clickhouseType) { return ColumnMapping.sliceMapping( diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java index 2eb5bd0c90525a..d2b16dba6689c0 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseTypeMapping.java @@ -15,6 +15,8 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; +import io.trino.plugin.jdbc.UnsupportedTypeHandling; +import io.trino.spi.type.Decimals; import io.trino.spi.type.TimeZoneKey; import io.trino.spi.type.UuidType; import io.trino.testing.AbstractTestQueryFramework; @@ -24,6 +26,7 @@ import io.trino.testing.datatype.CreateAsSelectDataSetup; import io.trino.testing.datatype.DataSetup; import io.trino.testing.datatype.SqlDataTypeTest; +import io.trino.testing.sql.JdbcSqlExecutor; import io.trino.testing.sql.SqlExecutor; import io.trino.testing.sql.TestTable; import io.trino.testing.sql.TrinoSqlExecutor; @@ -31,6 +34,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import java.math.RoundingMode; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.ZoneId; @@ -40,6 +44,13 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static io.trino.plugin.clickhouse.ClickHouseQueryRunner.TPCH_SCHEMA; +import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.ALLOW_OVERFLOW; +import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.STRICT; +import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.DECIMAL_DEFAULT_SCALE; +import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.DECIMAL_MAPPING; +import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.DECIMAL_ROUNDING_MODE; +import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.UNSUPPORTED_TYPE_HANDLING; +import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.DecimalType.createDecimalType; @@ -56,7 +67,10 @@ import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.type.IpAddressType.IPADDRESS; import static java.lang.String.format; +import static java.math.RoundingMode.HALF_UP; +import static java.math.RoundingMode.UNNECESSARY; import static java.time.ZoneOffset.UTC; +import static java.util.Arrays.asList; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @TestInstance(PER_CLASS) @@ -425,18 +439,267 @@ public void testDecimal() .addRoundTrip("decimal(30, 5)", "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))") .addRoundTrip("decimal(38, 0)", "CAST('27182818284590452353602874713526624977' AS decimal(38, 0))", createDecimalType(38, 0), "CAST('27182818284590452353602874713526624977' AS decimal(38, 0))") .addRoundTrip("decimal(38, 0)", "CAST('-27182818284590452353602874713526624977' AS decimal(38, 0))", createDecimalType(38, 0), "CAST('-27182818284590452353602874713526624977' AS decimal(38, 0))") - + .addRoundTrip("decimal(38, 38)", "CAST('0.27182818284590452353602874713526624977' AS decimal(38, 38))", createDecimalType(38, 38), "CAST('0.27182818284590452353602874713526624977' AS decimal(38, 38))") + .addRoundTrip("decimal(38, 38)", "CAST('-0.27182818284590452353602874713526624977' AS decimal(38, 38))", createDecimalType(38, 38), "CAST('-0.27182818284590452353602874713526624977' AS decimal(38, 38))") .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_decimal")) - - .addRoundTrip("decimal(3, 1)", "NULL", createDecimalType(3, 1), "CAST(NULL AS decimal(3,1))") - .addRoundTrip("decimal(30, 5)", "NULL", createDecimalType(30, 5), "CAST(NULL AS decimal(30,5))") - - .execute(getQueryRunner(), trinoCreateAsSelect("test_decimal")); + .execute(getQueryRunner(), trinoCreateAsSelect("test_decimal")) + .execute(getQueryRunner(), trinoCreateAndInsert("test_decimal")); SqlDataTypeTest.create() + .addRoundTrip("Nullable(decimal(3, 1))", "NULL", createDecimalType(3, 1), "CAST(NULL AS decimal(3,1))") .addRoundTrip("Nullable(decimal(3, 1))", "NULL", createDecimalType(3, 1), "CAST(NULL AS decimal(3,1))") .addRoundTrip("Nullable(decimal(30, 5))", "NULL", createDecimalType(30, 5), "CAST(NULL AS decimal(30,5))") .execute(getQueryRunner(), clickhouseCreateAndInsert("tpch.test_nullable_decimal")); + + SqlDataTypeTest.create() + .addRoundTrip(format("Decimal(%d, 5)", Decimals.MAX_PRECISION + 1), "1.1", createDecimalType(Decimals.MAX_PRECISION, 5), format("CAST(1.1 AS DECIMAL(%d, 5))", Decimals.MAX_PRECISION)) + .execute(getQueryRunner(), sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 5), clickhouseCreateAndInsert("tpch.test_unspecified_decimal")); + } + + @Test + public void testDecimalExceedingPrecisionIsIgnored() + { + testUnsupportedDataTypeIsIgnored("decimal(50,0)", "'12345678901234567890123456789012345678901234567890'"); + } + + @Test + public void testDecimalExceedingPrecisionMaxConvertedToVarchar() + { + testUnsupportedDataTypeConvertedToVarchar( + getSession(), + "Nullable(Decimal(50,0))", + "numeric", + "12345678901234567890123456789012345678901234567890", + "'12345678901234567890123456789012345678901234567890'"); + } + + protected Session sessionWithDecimalMappingAllowOverflow(RoundingMode roundingMode, int scale) + { + return Session.builder(getSession()) + .setCatalogSessionProperty("clickhouse", DECIMAL_MAPPING, ALLOW_OVERFLOW.name()) + .setCatalogSessionProperty("clickhouse", DECIMAL_ROUNDING_MODE, roundingMode.name()) + .setCatalogSessionProperty("clickhouse", DECIMAL_DEFAULT_SCALE, Integer.valueOf(scale).toString()) + .build(); + } + + protected Session sessionWithDecimalMappingStrict(UnsupportedTypeHandling unsupportedTypeHandling) + { + return Session.builder(getSession()) + .setCatalogSessionProperty("clickhouse", DECIMAL_MAPPING, STRICT.name()) + .setCatalogSessionProperty("clickhouse", UNSUPPORTED_TYPE_HANDLING, unsupportedTypeHandling.name()) + .build(); + } + + public void testUnsupportedDataTypeIsIgnored(String dataType, String dataValue) + { + JdbcSqlExecutor jse = new JdbcSqlExecutor(clickhouseServer.getJdbcUrl()); + try (TestTable table = new TestTable( + jse, + "tpch.test_unsupported_decimal", + format("(i Int32, ut %s) ENGINE=Log", dataType), + ImmutableList.of("1, " + dataValue))) { + assertQuery(format("SELECT * FROM %s", table.getName()), "VALUES 1"); + assertQuery(format("DESC %s", table.getName()), "VALUES ('i', 'integer','', '')"); // no 'unsupported_column' + + assertUpdate(format("INSERT INTO %s VALUES 3", table.getName()), 1); + assertQuery("SELECT * FROM " + table.getName(), "VALUES 1, 3"); + } + } + + private void testUnsupportedDataTypeConvertedToVarchar(Session session, String dataTypeName, String internalDataTypeName, String databaseValue, String trinoValue) + { + JdbcSqlExecutor jse = new JdbcSqlExecutor(clickhouseServer.getJdbcUrl()); + try (TestTable table = new TestTable( + jse, + "tpch.unsupported_type", + format("(key Int32, unsupported_column %s) Engine=Log", dataTypeName), + ImmutableList.of("1, NULL", "2, " + databaseValue))) { + Session convertToVarchar = Session.builder(session) + .setCatalogSessionProperty("clickhouse", UNSUPPORTED_TYPE_HANDLING, CONVERT_TO_VARCHAR.name()) + .build(); + assertQuery( + convertToVarchar, + "SELECT * FROM " + table.getName(), + format("VALUES (1, NULL), (2, %s)", trinoValue)); + assertQuery( + convertToVarchar, + format("SELECT key FROM %s WHERE unsupported_column = %s", table.getName(), trinoValue), + "VALUES 2"); + assertQuery( + convertToVarchar, + "DESC " + table.getName(), + "VALUES " + + "('key', 'integer', '', ''), " + + "('unsupported_column', 'varchar', '', '')"); + assertUpdate( + convertToVarchar, + format("INSERT INTO %s (key, unsupported_column) VALUES (3, NULL)", table.getName()), + 1); + assertQueryFails( + convertToVarchar, + format("INSERT INTO %s (key, unsupported_column) VALUES (4, %s)", table.getName(), trinoValue), + ".*Underlying type that is mapped to VARCHAR is not supported for INSERT.*"); + assertUpdate( + convertToVarchar, + format("INSERT INTO %s (key) VALUES 5", table.getName()), + 1); + assertQuery( + convertToVarchar, + "SELECT * FROM " + table.getName(), + format("VALUES (1, NULL), (2, %s), (3, NULL), (5, NULL)", trinoValue)); + } + } + + @Test + public void testDecimalExceedingPrecisionMaxWithExceedingIntegerValues() + { + JdbcSqlExecutor jse = new JdbcSqlExecutor(clickhouseServer.getJdbcUrl()); + + try (TestTable testTable = new TestTable(jse, "tpch.test_exceeding_max_decimal", + "(d_col decimal(65,25)) Engine=Log", + asList("1234567890123456789012345678901234567890.123456789", "-1234567890123456789012345678901234567890.123456789"))) { + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 0), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'decimal(38,0)')"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 0), + "SELECT d_col FROM " + testTable.getName(), + "Rounding necessary"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 0), + "SELECT d_col FROM " + testTable.getName(), + "Decimal overflow"); + assertQuery( + sessionWithDecimalMappingStrict(CONVERT_TO_VARCHAR), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'varchar')"); + assertQuery( + sessionWithDecimalMappingStrict(CONVERT_TO_VARCHAR), + "SELECT d_col FROM " + testTable.getName(), + "VALUES ('1234567890123456789012345678901234567890.1234567890000000000000000'), ('-1234567890123456789012345678901234567890.1234567890000000000000000')"); + } + } + + @Test + public void testDecimalExceedingPrecisionMaxWithNonExceedingIntegerValues() + { + JdbcSqlExecutor jse = new JdbcSqlExecutor(clickhouseServer.getJdbcUrl()); + + try (TestTable testTable = new TestTable( + jse, + "tpch.test_exceeding_max_decimal", + "(d_col decimal(60,20)) Engine=Log", + asList("123456789012345678901234567890.123456789012345", "-123456789012345678901234567890.123456789012345"))) { + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 0), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'decimal(38,0)')"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 0), + "SELECT d_col FROM " + testTable.getName(), + "Rounding necessary"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 0), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (123456789012345678901234567890), (-123456789012345678901234567890)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 8), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'decimal(38,8)')"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 8), + "SELECT d_col FROM " + testTable.getName(), + "Rounding necessary"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 8), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (123456789012345678901234567890.12345679), (-123456789012345678901234567890.12345679)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 22), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'decimal(38,20)')"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 20), + "SELECT d_col FROM " + testTable.getName(), + "Decimal overflow"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 9), + "SELECT d_col FROM " + testTable.getName(), + "Decimal overflow"); + assertQuery( + sessionWithDecimalMappingStrict(CONVERT_TO_VARCHAR), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'varchar')"); + assertQuery( + sessionWithDecimalMappingStrict(CONVERT_TO_VARCHAR), + "SELECT d_col FROM " + testTable.getName(), + "VALUES ('123456789012345678901234567890.12345678901234500000'), ('-123456789012345678901234567890.12345678901234500000')"); + } + } + + @Test + public void testDecimalExceedingPrecisionMaxWithSupportedValues() + { + testDecimalExceedingPrecisionMaxWithSupportedValues(40, 8); + testDecimalExceedingPrecisionMaxWithSupportedValues(50, 10); + } + + private void testDecimalExceedingPrecisionMaxWithSupportedValues(int typePrecision, int typeScale) + { + JdbcSqlExecutor jse = new JdbcSqlExecutor(clickhouseServer.getJdbcUrl()); + try (TestTable testTable = new TestTable( + jse, + "tpch.test_exceeding_max_decimal", + format("(d_col decimal(%d,%d)) Engine=Log", typePrecision, typeScale), + asList("12.01", "-12.01", "123", "-123", "1.12345678", "-1.12345678"))) { + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 0), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'decimal(38,0)')"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 0), + "SELECT d_col FROM " + testTable.getName(), + "Rounding necessary"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 0), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (12), (-12), (123), (-123), (1), (-1)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 3), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'decimal(38,3)')"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 3), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (12.01), (-12.01), (123), (-123), (1.123), (-1.123)"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 3), + "SELECT d_col FROM " + testTable.getName(), + "Rounding necessary"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 8), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col', 'decimal(38,8)')"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 8), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (12.01), (-12.01), (123), (-123), (1.12345678), (-1.12345678)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 9), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (12.01), (-12.01), (123), (-123), (1.12345678), (-1.12345678)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 8), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (12.01), (-12.01), (123), (-123), (1.12345678), (-1.12345678)"); + } + } + + protected String omitDatabasePrefix(String tableName) + { + String[] components = tableName.split("\\."); + return components[components.length - 1]; } @Test diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java index f3a07a230e1c37..0711381df0a7bd 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java @@ -897,6 +897,38 @@ public void testEnumPredicatePushdown() } } + @Test + public void testDecimalPredicatePushdown() + { + try (TestTable table = new TestTable( + onRemoteDatabase(), + "tpch.test_decimal_pushdown", + "(short_decimal decimal(9, 3), long_decimal decimal(30, 10)) Engine=Log", + List.of("123.321, 123456789.987654321"))) { + assertThat(query("SELECT * FROM " + table.getName() + " WHERE short_decimal <= 124")) + .matches("VALUES (CAST(123.321 AS decimal(9,3)), CAST(123456789.987654321 AS decimal(30, 10)))") + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE short_decimal <= 124")) + .matches("VALUES (CAST(123.321 AS decimal(9,3)), CAST(123456789.987654321 AS decimal(30, 10)))") + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE long_decimal <= 123456790")) + .matches("VALUES (CAST(123.321 AS decimal(9,3)), CAST(123456789.987654321 AS decimal(30, 10)))") + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE short_decimal <= 123.321")) + .matches("VALUES (CAST(123.321 AS decimal(9,3)), CAST(123456789.987654321 AS decimal(30, 10)))") + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE long_decimal <= 123456789.987654321")) + .matches("VALUES (CAST(123.321 AS decimal(9,3)), CAST(123456789.987654321 AS decimal(30, 10)))") + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE short_decimal = 123.321")) + .matches("VALUES (CAST(123.321 AS decimal(9,3)), CAST(123456789.987654321 AS decimal(30, 10)))") + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE long_decimal = 123456789.987654321")) + .matches("VALUES (CAST(123.321 AS decimal(9,3)), CAST(123456789.987654321 AS decimal(30, 10)))") + .isFullyPushedDown(); + } + } + @Override protected OptionalInt maxTableNameLength() { diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestTypeMapping.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestTypeMapping.java index 80556e674c5392..57f1688f58d791 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestTypeMapping.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseLatestTypeMapping.java @@ -14,9 +14,16 @@ package io.trino.plugin.clickhouse; import io.trino.testing.QueryRunner; +import io.trino.testing.sql.JdbcSqlExecutor; +import io.trino.testing.sql.TestTable; +import org.junit.jupiter.api.Test; import static io.trino.plugin.clickhouse.ClickHouseQueryRunner.createClickHouseQueryRunner; import static io.trino.plugin.clickhouse.TestingClickHouseServer.CLICKHOUSE_LATEST_IMAGE; +import static java.lang.String.format; +import static java.math.RoundingMode.HALF_UP; +import static java.math.RoundingMode.UNNECESSARY; +import static java.util.Arrays.asList; public class TestClickHouseLatestTypeMapping extends BaseClickHouseTypeMapping @@ -28,4 +35,63 @@ protected QueryRunner createQueryRunner() clickhouseServer = closeAfterClass(new TestingClickHouseServer(CLICKHOUSE_LATEST_IMAGE)); return createClickHouseQueryRunner(clickhouseServer); } + + @Test + public void testDecimalUnspecifiedPrecisionWithValues() + { + JdbcSqlExecutor jse = new JdbcSqlExecutor(clickhouseServer.getJdbcUrl()); + + // https://github.com/clickhouse/ClickHouse/pull/53328 restores the support of + // Decimal with unspecified precision and scale. + try (TestTable testTable = new TestTable( + jse, + "tpch.test_var_decimal", + "(d_col decimal) Engine=Log", + asList("1.12", "123456.789", "-1.12", "-123456.789"))) { + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 0), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col','decimal(10,0)')"); + + // Excessive digits in a fraction are discarded (not rounded). Excessive digits in integer part will lead to an exception. + // Danger: Overflow check is not implemented for Decimal128 and Decimal256. In case of overflow incorrect result is returned, no exception is thrown. + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 0), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (1), (123456), (-1), (-123456)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 0), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (1), (123456), (-1), (-123456)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 1), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col','decimal(10,0)')"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 1), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (1), (123456), (-1), (-123456)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(HALF_UP, 2), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (1), (123456), (-1), (-123456)"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 3), + format("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'tpch' AND table_name = '%s'", omitDatabasePrefix(testTable.getName())), + "VALUES ('d_col','decimal(10,0)')"); + assertQuery( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 3), + "SELECT d_col FROM " + testTable.getName(), + "VALUES (1), (123456), (-1), (-123456)"); + + // Check that integer part overflow leads to an exception + assertQuerySucceeds( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 2), + "INSERT INTO " + testTable.getName() + " VALUES (1234567890)"); + assertQueryFails( + sessionWithDecimalMappingAllowOverflow(UNNECESSARY, 2), + "INSERT INTO " + testTable.getName() + " VALUES (12345678901)", + "Cannot cast.*"); + } + } }