diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/AbstractRewriteCast.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/AbstractRewriteCast.java new file mode 100644 index 00000000000000..4cf615c08f3e28 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/AbstractRewriteCast.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.projection.ProjectFunctionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.Type; + +import java.util.Optional; +import java.util.function.BiFunction; + +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable; +import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; +import static java.util.Objects.requireNonNull; + +public abstract class AbstractRewriteCast + implements ProjectFunctionRule +{ + private static final Capture VALUE = newCapture(); + + private final BiFunction toTargetType; + + protected abstract Optional toJdbcTypeHandle(JdbcTypeHandle sourceType, Type targetType); + + protected String buildCast(@SuppressWarnings("unused") Type sourceType, @SuppressWarnings("unused") Type targetType, String expression, String castType) + { + return "CAST(%s AS %s)".formatted(expression, castType); + } + + public AbstractRewriteCast(BiFunction toTargetType) + { + this.toTargetType = requireNonNull(toTargetType, "toTargetType is null"); + } + + @Override + public Pattern getPattern() + { + return call() + .with(functionName().equalTo(CAST_FUNCTION_NAME)) + .with(argumentCount().equalTo(1)) + .with(argument(0).matching(variable().capturedAs(VALUE))); + } + + @Override + public Optional rewrite(ConnectorTableHandle handle, ConnectorExpression castExpression, Captures captures, RewriteContext context) + { + Variable variable = captures.get(VALUE); + JdbcTypeHandle sourceTypeJdbcHandle = ((JdbcColumnHandle) context.getAssignment(variable.getName())).getJdbcTypeHandle(); + Type targetType = castExpression.getType(); + Optional targetJdbcTypeHandle = toJdbcTypeHandle(sourceTypeJdbcHandle, targetType); + + if (targetJdbcTypeHandle.isEmpty()) { + return Optional.empty(); + } + + Optional value = context.rewriteExpression(variable); + if (value.isEmpty()) { + return Optional.empty(); + } + + Type sourceType = variable.getType(); + String castType = toTargetType.apply(context.getSession(), targetType); + + return Optional.of(new JdbcExpression( + buildCast(sourceType, targetType, value.get().expression(), castType), + value.get().parameters(), + targetJdbcTypeHandle.get())); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcCastPushdownTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcCastPushdownTest.java new file mode 100644 index 00000000000000..b13a2e5f5c411c --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcCastPushdownTest.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.sql.SqlExecutor; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public abstract class BaseJdbcCastPushdownTest + extends AbstractTestQueryFramework +{ + protected abstract String leftTable(); + + protected abstract String rightTable(); + + protected abstract SqlExecutor onRemoteDatabase(); + + protected abstract List supportedCastTypePushdown(); + + protected abstract List unsupportedCastTypePushdown(); + + protected abstract List failCast(); + + @Test + public void testProjectionPushdownWithCast() + { + for (CastTestCase testCase : supportedCastTypePushdown()) { + assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable()))) + .isFullyPushedDown(); + } + + for (CastTestCase testCase : unsupportedCastTypePushdown()) { + assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable()))) + .isNotFullyPushedDown(ProjectNode.class); + } + } + + @Test + public void testJoinPushdownWithCast() + { + for (CastTestCase testCase : supportedCastTypePushdown()) { + assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + } + + for (CastTestCase testCase : unsupportedCastTypePushdown()) { + assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .joinIsNotFullyPushedDown(); + } + } + + @Test + public void testCastFails() + { + for (CastTestCase testCase : failCast()) { + assertThatThrownBy(() -> getQueryRunner().execute("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable()))) + .hasMessageMatching("(.*)Cannot cast (.*) to (.*)"); + } + } + + public record CastTestCase(String sourceColumn, String castType, Optional targetColumn) + { + public CastTestCase(String sourceColumn, String castType) + { + this(sourceColumn, castType, Optional.empty()); + } + + public CastTestCase + { + requireNonNull(sourceColumn, "sourceColumn is null"); + requireNonNull(castType, "castType is null"); + requireNonNull(targetColumn, "targetColumn is null"); + } + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/CastDataTypeTestTable.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/CastDataTypeTestTable.java new file mode 100644 index 00000000000000..5d9e4ded10baa5 --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/CastDataTypeTestTable.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TemporaryRelation; +import io.trino.testing.sql.TestTable; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.testing.Closeables.closeAll; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public final class CastDataTypeTestTable + implements TemporaryRelation +{ + private final List testCaseRows = new ArrayList<>(); + private final int rowCount; + + private TestTable testTable; + + private CastDataTypeTestTable(int rowCount) + { + this.rowCount = rowCount; + } + + public static CastDataTypeTestTable create(int rowCount) + { + return new CastDataTypeTestTable(rowCount); + } + + public CastDataTypeTestTable addColumn(String columnName, String columnType, List inputValues) + { + checkArgument(rowCount == inputValues.size(), "Expected input size: %s, but found: %s", rowCount, inputValues.size()); + testCaseRows.add(new TestCaseInput(columnName, columnType, inputValues)); + return this; + } + + public CastDataTypeTestTable execute(SqlExecutor sqlExecutor, String tableNamePrefix) + { + checkState(!testCaseRows.isEmpty(), "No test case rows"); + List columnsWithType = new ArrayList<>(); + for (TestCaseInput input : testCaseRows) { + columnsWithType.add("%s %s".formatted(input.columnName(), input.columnType())); + } + String tableDefinition = columnsWithType.stream() + .collect(joining(", ", "(", ")")); + + List rowsToInsert = IntStream.range(0, rowCount) + .mapToObj(rowId -> testCaseRows.stream() + .map(TestCaseInput::inputValues) + .map(rows -> rows.get(rowId)) + .map(String::valueOf) + .collect(joining(","))) + .collect(toImmutableList()); + testTable = new TestTable(sqlExecutor, tableNamePrefix, tableDefinition, rowsToInsert); + return this; + } + + @Override + public String getName() + { + return testTable.getName(); + } + + @Override + public void close() + { + try { + closeAll(testTable); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + public record TestCaseInput(String columnName, String columnType, List inputValues) + { + public TestCaseInput + { + requireNonNull(columnName, "columnName is null"); + requireNonNull(columnType, "columnType is null"); + inputValues = new ArrayList<>(requireNonNull(inputValues, "inputValues is null")); + } + } +} diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java index ae2b4320714b2a..3691e141a03024 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java @@ -21,6 +21,8 @@ import io.trino.plugin.base.aggregation.AggregateFunctionRule; import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.base.mapping.IdentifierMapping; +import io.trino.plugin.base.projection.ProjectFunctionRewriter; +import io.trino.plugin.base.projection.ProjectFunctionRule; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.BooleanWriteFunction; @@ -165,10 +167,10 @@ public class OracleClient private static final int MAX_BYTES_PER_CHAR = 4; private static final int ORACLE_VARCHAR2_MAX_BYTES = 4000; - private static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR; + public static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR; private static final int ORACLE_CHAR_MAX_BYTES = 2000; - private static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR; + public static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR; private static final int PRECISION_OF_UNSPECIFIED_NUMBER = 127; @@ -215,6 +217,7 @@ public class OracleClient private final boolean synonymsEnabled; private final ConnectorExpressionRewriter connectorExpressionRewriter; + private final ProjectFunctionRewriter projectFunctionRewriter; private final AggregateFunctionRewriter aggregateFunctionRewriter; private final Optional fetchSize; @@ -243,6 +246,12 @@ public OracleClient( .add(new RewriteStringComparison()) .build(); + this.projectFunctionRewriter = new ProjectFunctionRewriter<>( + this.connectorExpressionRewriter, + ImmutableSet.>builder() + .add(new RewriteCast((session, type) -> toWriteMapping(session, type).getDataType())) + .build()); + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(TRINO_BIGINT_TYPE, Optional.of("NUMBER"), Optional.of(0), Optional.of(0), Optional.empty(), Optional.empty()); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( connectorExpressionRewriter, @@ -559,6 +568,12 @@ public Optional convertPredicate(ConnectorSession sessi return connectorExpressionRewriter.rewrite(session, expression, assignments); } + @Override + public Optional convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map assignments) + { + return projectFunctionRewriter.rewrite(session, handle, expression, assignments); + } + private static Optional toTypeHandle(DecimalType decimalType) { return Optional.of(new JdbcTypeHandle(OracleTypes.NUMBER, Optional.of("NUMBER"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty())); diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/RewriteCast.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/RewriteCast.java new file mode 100644 index 00000000000000..2885c6635b60b9 --- /dev/null +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/RewriteCast.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.oracle; + +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.expression.AbstractRewriteCast; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.CharType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarcharType; +import oracle.jdbc.OracleTypes; + +import java.util.Optional; +import java.util.function.BiFunction; + +import static io.trino.plugin.oracle.OracleClient.ORACLE_CHAR_MAX_CHARS; +import static io.trino.plugin.oracle.OracleClient.ORACLE_VARCHAR2_MAX_CHARS; + +public class RewriteCast + extends AbstractRewriteCast +{ + public RewriteCast(BiFunction toTargetType) + { + super(toTargetType); + } + + @Override + protected String buildCast(Type sourceType, Type targetType, String expression, String castType) + { + if (sourceType instanceof CharType sourceCharType && targetType instanceof CharType targetCharType) { + if (sourceCharType.getLength() < targetCharType.getLength()) { + // Do not cast unnecessary with extra space padding when target char type has more length than source char type + return expression; + } + } + return "CAST(%s AS %s)".formatted(expression, castType); + } + + @Override + protected Optional toJdbcTypeHandle(JdbcTypeHandle sourceType, Type targetType) + { + if (!pushdownSupported(sourceType, targetType)) { + return Optional.empty(); + } + + if (targetType instanceof CharType charType) { + return Optional.of(new JdbcTypeHandle(OracleTypes.CHAR, Optional.of(charType.getBaseName()), Optional.of(charType.getLength()), Optional.empty(), Optional.empty(), Optional.empty())); + } + if (targetType instanceof VarcharType varcharType) { + return Optional.of(new JdbcTypeHandle(OracleTypes.VARCHAR, Optional.of(varcharType.getBaseName()), varcharType.getLength(), Optional.empty(), Optional.empty(), Optional.empty())); + } + return Optional.empty(); + } + + private boolean pushdownSupported(JdbcTypeHandle sourceType, Type targetType) + { + if (targetType instanceof CharType charType) { + // Oracle will throw an error on casts to char(n>ORACLE_CHAR_MAX_CHARS) "ORA-00932: inconsistent datatypes: expected - got NCLOB" + return charType.getLength() <= ORACLE_CHAR_MAX_CHARS + && supportedSourceTypeToCastToChar(sourceType); + } + if (targetType instanceof VarcharType varcharType && !varcharType.isUnbounded()) { + // unbounded varchar and char(n>ORACLE_VARCHAR2_MAX_CHARS) gets written as nclob. + // pushdown does not happen when comparing nclob type variable, so skipping to pushdown the cast for nclob type variable. + return varcharType.getLength().orElseThrow() <= ORACLE_VARCHAR2_MAX_CHARS + && supportedSourceTypeToCastToVarchar(sourceType); + } + return false; + } + + private static boolean supportedSourceTypeToCastToChar(JdbcTypeHandle sourceType) + { + return switch (sourceType.jdbcType()) { + case OracleTypes.CHAR, + OracleTypes.VARCHAR, + OracleTypes.NCHAR, + OracleTypes.NVARCHAR, + OracleTypes.CLOB, + OracleTypes.NCLOB -> true; + default -> false; + }; + } + + private static boolean supportedSourceTypeToCastToVarchar(JdbcTypeHandle sourceType) + { + return switch (sourceType.jdbcType()) { + case OracleTypes.NUMBER, + OracleTypes.CHAR, + OracleTypes.VARCHAR, + OracleTypes.NCHAR, + OracleTypes.NVARCHAR, + OracleTypes.CLOB, + OracleTypes.NCLOB -> true; + default -> false; + }; + } +} diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleCastPushdownTest.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleCastPushdownTest.java new file mode 100644 index 00000000000000..4d1b4f2105bf90 --- /dev/null +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleCastPushdownTest.java @@ -0,0 +1,436 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.oracle; + +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.plugin.jdbc.BaseJdbcCastPushdownTest; +import io.trino.plugin.jdbc.CastDataTypeTestTable; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.testing.sql.TestTable; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; + +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class BaseOracleCastPushdownTest + extends BaseJdbcCastPushdownTest +{ + private CastDataTypeTestTable left; + private CastDataTypeTestTable right; + + @BeforeAll + public void setup() + { + left = closeAfterClass(CastDataTypeTestTable.create(3) + .addColumn("id", "number(10)", asList(11, 12, 13)) + .addColumn("c_number_3", "number(3)", asList(1, 2, null)) // tinyint in trino + .addColumn("c_number_5", "number(5)", asList(1, 2, null)) // smallint in trino + .addColumn("c_number_10", "number(10)", asList(1, 2, null)) // integer in trino + .addColumn("c_number_19", "number(19)", asList(1, 2, null)) // bigint in trino + .addColumn("c_float", "float", asList(1.23, 2.67, null)) // double in trino + .addColumn("c_float_5", "float(5)", asList(1.23, 2.67, null)) // double in trino + .addColumn("c_binary_float", "binary_float", asList(1.23, 2.67, null)) + .addColumn("c_binary_double", "binary_double", asList(1.23, 2.67, null)) + .addColumn("c_nan", "binary_double", asList("BINARY_DOUBLE_NAN", "BINARY_FLOAT_NAN", null)) + .addColumn("c_infinity", "binary_double", asList("BINARY_DOUBLE_INFINITY", "-BINARY_DOUBLE_INFINITY", null)) + .addColumn("c_number_15", "decimal(15)", asList(1, 2, null)) + .addColumn("c_number_10_2", "decimal(10, 2)", asList(1.23, 2.67, null)) + .addColumn("c_number_30_2", "decimal(30, 2)", asList(1.23, 2.67, null)) + .addColumn("c_char_10", "char(10)", asList("'India'", "'Poland'", null)) + .addColumn("c_char_50", "char(50)", asList("'India'", "'Poland'", null)) + .addColumn("c_char_501", "char(501)", asList("'India'", "'Poland'", null)) // greater than ORACLE_CHAR_MAX_CHARS + .addColumn("c_char_520", "char(520)", asList("'India'", "'Poland'", null)) // greater than ORACLE_CHAR_MAX_CHARS + .addColumn("c_nchar_10", "nchar(10)", asList("N'India'", "N'Poland'", null)) + .addColumn("c_varchar_10", "varchar2(10)", asList("'India'", "'Poland'", null)) + .addColumn("c_varchar_10_byte", "varchar2(10 byte)", asList("'India'", "'Poland'", null)) + .addColumn("c_varchar_50", "varchar2(50)", asList("'India'", "'Poland'", null)) + .addColumn("c_varchar_1001", "varchar2(1001)", asList("'India'", "'Poland'", null)) // greater than ORACLE_VARCHAR2_MAX_CHARS + .addColumn("c_varchar_1020", "varchar2(1020)", asList("'India'", "'Poland'", null)) // greater than ORACLE_VARCHAR2_MAX_CHARS + .addColumn("c_varchar_numeric", "varchar2(50)", asList("'123'", "'456'", null)) + .addColumn("c_varchar_decimal", "varchar2(50)", asList("'1.23'", "'2.67'", null)) + .addColumn("c_varchar_alpha_numeric", "varchar2(50)", asList("'H311o'", "'123Hey'", null)) + .addColumn("c_varchar_date", "varchar2(50)", asList("'2024-09-08'", "'2019-08-15'", null)) + .addColumn("c_varchar_timestamp", "varchar2(50)", asList("'2024-09-08 01:02:03.666'", "'2019-08-15 09:08:07.333'", null)) + .addColumn("c_varchar_timestamptz", "varchar2(50)", asList("'2024-09-08 01:02:03.666 +05:30'", "'2019-08-15 09:08:07.333 +05:30'", null)) + .addColumn("c_nvarchar_100", "nvarchar2(100)", asList("N'India'", "N'Poland'", null)) // varchar(p) in trino + .addColumn("c_clob", "clob", asList("'India'", "'Poland'", null)) // varchar in trino + .addColumn("c_nclob", "nclob", asList("N'India'", "N'Poland'", null)) // varchar in trino + .addColumn("c_blob", "blob", asList("HEXTORAW('496E646961')", "HEXTORAW('506F6C616E64')", null)) // varbinary in trino + .addColumn("c_raw_200", "raw(200)", asList("HEXTORAW('496E646961')", "HEXTORAW('506F6C616E64')", null)) // varbinary in trino + .addColumn("c_date", "date", asList("DATE '2024-09-08'", "DATE '2019-08-15'", null)) + .addColumn("c_timestamp", "timestamp", asList("TIMESTAMP '2024-09-08 01:02:03.666'", "TIMESTAMP '2019-08-15 09:08:07.333'", null)) + .addColumn("c_timestamptz", "timestamp with time zone", asList("TIMESTAMP '2024-09-08 01:02:03.666 +05:30'", "TIMESTAMP '2019-08-15 09:08:07.333 +05:30'", null)) + + // unsupported in trino + .addColumn("c_number", "number", asList(1, 2, null)) + .addColumn("c_timestamp_ltz", "timestamp with local time zone", asList("TIMESTAMP '2024-09-08 01:02:03.666'", "TIMESTAMP '2019-08-15 09:08:07.333'", null)) + .addColumn("c_interval_ym", "interval year to month", asList("INTERVAL '1-2' YEAR TO MONTH", "INTERVAL '3-4' YEAR TO MONTH", null)) + .addColumn("c_interval_ds", "interval day to second", asList("INTERVAL '10 11:12:13.456' DAY TO SECOND", "INTERVAL '10 12:13:14.456' DAY TO SECOND", null)) + .addColumn("c_long", "long", asList("'India'", "'Poland'", null)) + .addColumn("c_xmltype", "xmltype", asList("XMLTYPE('Sample XML-1')", "XMLTYPE('Sample XML-2')", null)) + .execute(onRemoteDatabase(), "left_table_")); + + // 2nd row value is different in right table than left table + right = closeAfterClass(CastDataTypeTestTable.create(3) + .addColumn("id", "number(10)", asList(21, 22, 23)) + .addColumn("c_number_3", "number(3)", asList(1, 22, null)) // tinyint in trino + .addColumn("c_number_5", "number(5)", asList(1, 22, null)) // smallint in trino + .addColumn("c_number_10", "number(10)", asList(1, 22, null)) // integer in trino + .addColumn("c_number_19", "number(19)", asList(1, 22, null)) // bigint in trino + .addColumn("c_float", "float", asList(1.23, 22.67, null)) // double in trino + .addColumn("c_float_5", "float(5)", asList(1.23, 22.67, null)) // double in trino + .addColumn("c_binary_float", "binary_float", asList(1.23, 22.67, null)) + .addColumn("c_binary_double", "binary_double", asList(1.23, 22.67, null)) + .addColumn("c_nan", "binary_double", asList("BINARY_DOUBLE_NAN", "BINARY_DOUBLE_NAN", null)) + .addColumn("c_infinity", "binary_double", asList("BINARY_DOUBLE_INFINITY", "BINARY_DOUBLE_INFINITY", null)) + .addColumn("c_number_15", "decimal(15)", asList(1, 22, null)) + .addColumn("c_number_10_2", "decimal(10, 2)", asList(1.23, 22.67, null)) + .addColumn("c_number_30_2", "decimal(30, 2)", asList(1.23, 22.67, null)) + .addColumn("c_char_10", "char(10)", asList("'India'", "'France'", null)) + .addColumn("c_char_50", "char(50)", asList("'India'", "'France'", null)) + .addColumn("c_char_501", "char(501)", asList("'India'", "'France'", null)) // greater than ORACLE_CHAR_MAX_CHARS + .addColumn("c_char_520", "char(520)", asList("'India'", "'France'", null)) // greater than ORACLE_CHAR_MAX_CHARS + .addColumn("c_nchar_10", "nchar(10)", asList("N'India'", "N'France'", null)) + .addColumn("c_varchar_10", "varchar2(10)", asList("'India'", "'France'", null)) + .addColumn("c_varchar_10_byte", "varchar2(10 byte)", asList("'India'", "'France'", null)) + .addColumn("c_varchar_50", "varchar2(50)", asList("'India'", "'France'", null)) + .addColumn("c_varchar_1001", "varchar2(1001)", asList("'India'", "'France'", null)) // greater than ORACLE_VARCHAR2_MAX_CHARS + .addColumn("c_varchar_1020", "varchar2(1020)", asList("'India'", "'France'", null)) // greater than ORACLE_VARCHAR2_MAX_CHARS + .addColumn("c_varchar_numeric", "varchar2(50)", asList("'123'", "'234'", null)) + .addColumn("c_varchar_decimal", "varchar2(50)", asList("'1.23'", "'22.67'", null)) + .addColumn("c_varchar_alpha_numeric", "varchar2(50)", asList("'H311o'", "'123Bye'", null)) + .addColumn("c_varchar_date", "varchar2(50)", asList("'2024-09-08'", "'2020-08-15'", null)) + .addColumn("c_varchar_timestamp", "varchar2(50)", asList("'2024-09-08 01:02:03.666'", "'2020-08-15 09:08:07.333'", null)) + .addColumn("c_varchar_timestamptz", "varchar2(50)", asList("'2024-09-08 01:02:03.666 +05:30'", "'2020-08-15 09:08:07.333 +05:30'", null)) + .addColumn("c_nvarchar_100", "nvarchar2(100)", asList("N'India'", "N'France'", null)) // varchar(p) in trino + .addColumn("c_clob", "clob", asList("'India'", "'France'", null)) // varchar in trino + .addColumn("c_nclob", "nclob", asList("N'India'", "N'France'", null)) // varchar in trino + .addColumn("c_blob", "blob", asList("HEXTORAW('496E646961')", "HEXTORAW('4672616E6365')", null)) // varbinary in trino + .addColumn("c_raw_200", "raw(200)", asList("HEXTORAW('496E646961')", "HEXTORAW('4672616E6365')", null)) // varbinary in trino + .addColumn("c_date", "date", asList("DATE '2024-09-08'", "DATE '2020-08-15'", null)) + .addColumn("c_timestamp", "timestamp", asList("TIMESTAMP '2024-09-08 01:02:03.666'", "TIMESTAMP '2020-08-15 09:08:07.333'", null)) + .addColumn("c_timestamptz", "timestamp with time zone", asList("TIMESTAMP '2024-09-08 01:02:03.666 +05:30'", "TIMESTAMP '2020-08-15 09:08:07.333 +05:30'", null)) + + // unsupported in trino + .addColumn("c_number", "number", asList(1, 22, null)) + .addColumn("c_timestamp_ltz", "timestamp with local time zone", asList("TIMESTAMP '2024-09-08 01:02:03.666'", "TIMESTAMP '2020-08-15 09:08:07.333'", null)) + .addColumn("c_interval_ym", "interval year to month", asList("INTERVAL '1-2' YEAR TO MONTH", "INTERVAL '4-5' YEAR TO MONTH", null)) + .addColumn("c_interval_ds", "interval day to second", asList("INTERVAL '10 11:12:13.456' DAY TO SECOND", "INTERVAL '11 12:13:14.456' DAY TO SECOND", null)) + .addColumn("c_long", "long", asList("'India'", "'France'", null)) + .addColumn("c_xmltype", "xmltype", asList("XMLTYPE('Sample XML-1')", "XMLTYPE('Sample XML-3')", null)) + .execute(onRemoteDatabase(), "right_table_")); + } + + @Override + protected String leftTable() + { + return left.getName(); + } + + @Override + protected String rightTable() + { + return right.getName(); + } + + @Test + public void testCastPushdownSpecialCase() + { + for (CastTestCase testCase : specialCaseNClob()) { + // Projection pushdown is supported, because trino converts the clob type to nclob thus cast is not required + assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable()))) + .isFullyPushedDown(); + // join pushdown is not supported, because comparison between nclob is not pushdown + assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .joinIsNotFullyPushedDown(); + } + } + + @Test + public void testJoinPushdownWithNestedCast() + { + CastTestCase testCase = new CastTestCase("c_varchar_10", "varchar(100)", Optional.of("c_varchar_50")); + assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + } + + @Test + public void testAllJoinPushdownWithCast() + { + CastTestCase testCase = new CastTestCase("c_varchar_10", "varchar(50)", Optional.of("c_varchar_50")); + assertThat(query("SELECT l.id FROM %s l LEFT JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l RIGHT JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l INNER JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l FULL JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + + testCase = new CastTestCase("c_varchar_10", "varchar(10)", Optional.of("c_varchar_50")); + assertThat(query("SELECT l.id FROM %s l LEFT JOIN %s r ON l.%s = CAST(r.%s AS %s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.targetColumn().orElseThrow(), testCase.castType()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l RIGHT JOIN %s r ON l.%s = CAST(r.%s AS %s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.targetColumn().orElseThrow(), testCase.castType()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l INNER JOIN %s r ON l.%s = CAST(r.%s AS %s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.targetColumn().orElseThrow(), testCase.castType()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l FULL JOIN %s r ON l.%s = CAST(r.%s AS %s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.targetColumn().orElseThrow(), testCase.castType()))) + .isFullyPushedDown(); + + testCase = new CastTestCase("c_varchar_10", "varchar(200)", Optional.of("c_varchar_50")); + assertThat(query("SELECT l.id FROM %s l LEFT JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l RIGHT JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l INNER JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + assertThat(query("SELECT l.id FROM %s l FULL JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) + .isFullyPushedDown(); + } + + @Test + public void testCastPushdownClobSensitivity() + { + // Verify that clob/nclob join condition is not applied in a case-insensitive way + try (TestTable leftTable = new TestTable( + onRemoteDatabase(), + "l_clob_sensitivity_", + "(id int, c_varchar_50 varchar(50), c_clob clob, c_nclob nclob)", + asList("11, 'India', 'India', 'India'", "12, 'Poland', 'Poland', 'Poland'")); + TestTable rightTable = new TestTable( + onRemoteDatabase(), + "r_clob_sensitivity_", + "(id int, c_varchar_50 varchar(50), c_clob clob, c_nclob nclob)", + asList("21, 'INDIA', 'INDIA', 'INDIA'", "22, 'POLAND', 'POLAND', 'POLAND'", "23, 'India', 'India', 'India'"))) { + assertThat(query("SELECT r.id, r.c_nclob FROM %s l JOIN %s r ON CAST(l.c_nclob AS VARCHAR(50)) = r.c_varchar_50".formatted(leftTable.getName(), rightTable.getName()))) + .matches("VALUES (CAST(23 AS DECIMAL(38, 0)), VARCHAR 'India')") + .isFullyPushedDown(); + + assertThat(query("SELECT r.id, r.c_clob FROM %s l JOIN %s r ON CAST(l.c_clob AS VARCHAR(50)) = r.c_varchar_50".formatted(leftTable.getName(), rightTable.getName()))) + .matches("VALUES (CAST(23 AS DECIMAL(38, 0)), VARCHAR 'India')") + .isFullyPushedDown(); + } + } + + @Test + public void testCastPushdownDisabled() + { + Session sessionWithoutComplexExpressionPushdown = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "complex_expression_pushdown", "false") + .build(); + assertThat(query(sessionWithoutComplexExpressionPushdown, "SELECT CAST (c_varchar_10 AS VARCHAR(100)) FROM %s".formatted(leftTable()))) + .isNotFullyPushedDown(ProjectNode.class); + } + + @Test + public void testCastPushdownWithForcedTypedToVarchar() + { + // These column types are not supported by default by trino. These types are forced mapped to varchar. + assertThat(query("SELECT CAST(c_interval_ym AS VARCHAR(100)) FROM %s".formatted(leftTable()))) + .isNotFullyPushedDown(ProjectNode.class); + assertThat(query("SELECT CAST(c_timestamp_ltz AS VARCHAR(100)) FROM %s".formatted(leftTable()))) + .isNotFullyPushedDown(ProjectNode.class); + } + + @Override + protected List supportedCastTypePushdown() + { + return ImmutableList.of( + new CastTestCase("c_char_10", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_char_50", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_char_501", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_char_520", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_nchar_10", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_varchar_10", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_varchar_10_byte", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_varchar_1001", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_varchar_1020", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_nvarchar_100", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_clob", "char(50)", Optional.of("c_char_50")), + new CastTestCase("c_nclob", "char(50)", Optional.of("c_char_50")), + + new CastTestCase("c_char_10", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_char_50", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_char_501", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_char_520", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_nchar_10", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_varchar_10", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_varchar_10_byte", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_varchar_1001", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_varchar_1020", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_nvarchar_100", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_clob", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_nclob", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_number_3", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_number_5", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_number_10", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_number_19", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_number_15", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_number_10_2", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_number_30_2", "varchar(50)", Optional.of("c_varchar_50"))); + } + + @Override + protected List unsupportedCastTypePushdown() + { + return ImmutableList.of( + new CastTestCase("c_char_10", "char(501)", Optional.of("c_char_501")), + new CastTestCase("c_char_501", "char(520)", Optional.of("c_char_520")), + new CastTestCase("c_varchar_10", "char(501)", Optional.of("c_char_501")), + new CastTestCase("c_varchar_1001", "char(501)", Optional.of("c_char_501")), + new CastTestCase("c_clob", "char(501)", Optional.of("c_char_501")), + new CastTestCase("c_nclob", "char(501)", Optional.of("c_char_501")), + + new CastTestCase("c_char_10", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_char_501", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_varchar_10", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_varchar_1001", "varchar(1020)", Optional.of("c_varchar_1020")), + new CastTestCase("c_clob", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_nclob", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_number_3", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_number_5", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_number_10", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_number_19", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_float", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_float_5", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_binary_float", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_binary_double", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_nan", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_infinity", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_number_15", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_number_10_2", "varchar(1001)", Optional.of("c_varchar_1001")), + new CastTestCase("c_number_30_2", "varchar(1001)", Optional.of("c_varchar_1001")), + + new CastTestCase("c_char_10", "varchar", Optional.of("c_clob")), + new CastTestCase("c_char_501", "varchar", Optional.of("c_clob")), + new CastTestCase("c_varchar_10", "varchar", Optional.of("c_clob")), + new CastTestCase("c_varchar_1001", "varchar", Optional.of("c_clob")), + new CastTestCase("c_number_3", "varchar", Optional.of("c_clob")), + new CastTestCase("c_number_5", "varchar", Optional.of("c_clob")), + new CastTestCase("c_number_10", "varchar", Optional.of("c_clob")), + new CastTestCase("c_number_19", "varchar", Optional.of("c_clob")), + new CastTestCase("c_float", "varchar", Optional.of("c_clob")), + new CastTestCase("c_float_5", "varchar", Optional.of("c_clob")), + new CastTestCase("c_binary_float", "varchar", Optional.of("c_clob")), + new CastTestCase("c_binary_double", "varchar", Optional.of("c_clob")), + new CastTestCase("c_nan", "varchar", Optional.of("c_clob")), + new CastTestCase("c_infinity", "varchar", Optional.of("c_clob")), + new CastTestCase("c_number_15", "varchar", Optional.of("c_clob")), + new CastTestCase("c_number_10_2", "varchar", Optional.of("c_clob")), + new CastTestCase("c_number_30_2", "varchar", Optional.of("c_clob")), + + new CastTestCase("c_char_10", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_char_501", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_varchar_10", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_varchar_1001", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_number_3", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_number_5", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_number_10", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_number_19", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_float", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_float_5", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_binary_float", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_binary_double", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_nan", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_infinity", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_number_15", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_number_10_2", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_number_30_2", "varchar", Optional.of("c_nclob")), + + new CastTestCase("c_number_3", "tinyint", Optional.of("c_number_5")), + new CastTestCase("c_number_3", "smallint", Optional.of("c_number_10")), + new CastTestCase("c_number_3", "integer", Optional.of("c_number_19")), + new CastTestCase("c_number_3", "bigint", Optional.of("c_float")), + new CastTestCase("c_number_3", "real", Optional.of("c_float_5")), + new CastTestCase("c_number_3", "double", Optional.of("c_binary_float")), + new CastTestCase("c_number_3", "double", Optional.of("c_binary_double")), + new CastTestCase("c_number_3", "decimal(15)", Optional.of("c_number_15")), + new CastTestCase("c_number_3", "decimal(10, 2)", Optional.of("c_number_10_2")), + new CastTestCase("c_number_3", "decimal(30, 2)", Optional.of("c_number_30_2")), + new CastTestCase("c_varchar_10", "varbinary", Optional.of("c_blob")), + new CastTestCase("c_varchar_10", "varbinary", Optional.of("c_raw_200")), + new CastTestCase("c_timestamp", "date", Optional.of("c_date")), + new CastTestCase("c_timestamptz", "timestamp", Optional.of("c_timestamp")), + new CastTestCase("c_timestamp", "timestamp with time zone", Optional.of("c_timestamptz")), + + // When data inserted from Trino, below cases give mismatched value between pushdown + // and without pushdown, So not supporting cast pushdown for these cases + new CastTestCase("c_float", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_float_5", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_binary_float", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_binary_double", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_nan", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_infinity", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_date", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_timestamp", "varchar(50)", Optional.of("c_varchar_50")), + new CastTestCase("c_timestamptz", "varchar(50)", Optional.of("c_varchar_50"))); + } + + @Override + protected List failCast() + { + return ImmutableList.of( + new CastTestCase("c_number_3", "char(50)"), + new CastTestCase("c_number_5", "char(50)"), + new CastTestCase("c_number_10", "char(50)"), + new CastTestCase("c_number_19", "char(50)"), + new CastTestCase("c_float", "char(50)"), + new CastTestCase("c_float_5", "char(50)"), + new CastTestCase("c_binary_float", "char(50)"), + new CastTestCase("c_binary_double", "char(50)"), + new CastTestCase("c_number_15", "char(50)"), + new CastTestCase("c_number_10_2", "char(50)"), + new CastTestCase("c_number_30_2", "char(50)"), + new CastTestCase("c_date", "char(50)"), + new CastTestCase("c_timestamp", "char(50)"), + new CastTestCase("c_timestamptz", "char(50)"), + new CastTestCase("c_blob", "char(50)"), + new CastTestCase("c_raw_200", "char(50)"), + + new CastTestCase("c_number_3", "char(501)"), + new CastTestCase("c_number_5", "char(501)"), + new CastTestCase("c_number_10", "char(501)"), + new CastTestCase("c_number_19", "char(501)"), + new CastTestCase("c_float", "char(501)"), + new CastTestCase("c_float_5", "char(501)"), + new CastTestCase("c_binary_float", "char(501)"), + new CastTestCase("c_binary_double", "char(501)"), + new CastTestCase("c_number_15", "char(501)"), + new CastTestCase("c_number_10_2", "char(501)"), + new CastTestCase("c_number_30_2", "char(501)"), + new CastTestCase("c_date", "char(501)"), + new CastTestCase("c_timestamp", "char(501)"), + new CastTestCase("c_timestamptz", "char(501)"), + new CastTestCase("c_blob", "char(501)"), + new CastTestCase("c_raw_200", "char(501)"), + + new CastTestCase("c_blob", "varchar(50)"), + new CastTestCase("c_raw_200", "varchar(50)")); + } + + private static List specialCaseNClob() + { + // Trino converts clob type to nclob + return ImmutableList.of( + new CastTestCase("c_clob", "varchar", Optional.of("c_clob")), + new CastTestCase("c_nclob", "varchar", Optional.of("c_clob")), + new CastTestCase("c_clob", "varchar", Optional.of("c_nclob")), + new CastTestCase("c_nclob", "varchar", Optional.of("c_nclob"))); + } +} diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java index 21ceb79f87392d..d61a2bf4c29a0c 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java @@ -496,6 +496,39 @@ private void predicatePushdownTest(String oracleType, String oracleLiteral, Stri } } + @Test + public void testJoinPushdownWithImplicitCast() + { + try (TestTable leftTable = new TestTable(getQueryRunner()::execute, "left_table", "(id int, varchar_50 varchar(50))", ImmutableList.of("(1, 'India')", "(2, 'Poland')")); + TestTable rightTable = new TestTable(getQueryRunner()::execute, "right_table_", "(varchar_100 varchar(100), varchar_unbounded varchar)", ImmutableList.of("('India', 'Japan')", "('France', 'Poland')"))) { + String leftTableName = leftTable.getName(); + String rightTableName = rightTable.getName(); + Session session = joinPushdownEnabled(getSession()); + + // Implicit cast between bounded varchar + String joinWithBoundedVarchar = "SELECT id FROM %s l %s %s r ON l.varchar_50 = r.varchar_100".formatted(leftTableName, "%s", rightTableName); + assertThat(query(session, joinWithBoundedVarchar.formatted("LEFT JOIN"))) + .isFullyPushedDown(); + assertThat(query(session, joinWithBoundedVarchar.formatted("RIGHT JOIN"))) + .isFullyPushedDown(); + assertThat(query(session, joinWithBoundedVarchar.formatted("INNER JOIN"))) + .isFullyPushedDown(); + assertThat(query(session, joinWithBoundedVarchar.formatted("FULL JOIN"))) + .isFullyPushedDown(); + + // Implicit cast between bounded and unbounded varchar + String joinWithUnboundedVarchar = "SELECT id FROM %s l %s %s r ON l.varchar_50 = r.varchar_unbounded".formatted(leftTableName, "%s", rightTableName); + assertThat(query(session, joinWithUnboundedVarchar.formatted("LEFT JOIN"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session, joinWithUnboundedVarchar.formatted("RIGHT JOIN"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session, joinWithUnboundedVarchar.formatted("INNER JOIN"))) + .joinIsNotFullyPushedDown(); + assertThat(query(session, joinWithUnboundedVarchar.formatted("FULL JOIN"))) + .joinIsNotFullyPushedDown(); + } + } + protected String getUser() { return TEST_USER; diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleCastPushdown.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleCastPushdown.java new file mode 100644 index 00000000000000..ed533be72cc1b6 --- /dev/null +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleCastPushdown.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.oracle; + +import com.google.common.collect.ImmutableMap; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.SqlExecutor; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +public class TestOracleCastPushdown + extends BaseOracleCastPushdownTest +{ + private TestingOracleServer oracleServer; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + oracleServer = closeAfterClass(new TestingOracleServer()); + return OracleQueryRunner.builder(oracleServer) + .addConnectorProperties(ImmutableMap.builder() + .put("jdbc-types-mapped-to-varchar", "interval year(2) to month, timestamp(6) with local time zone") + .put("join-pushdown.enabled", "true") + .buildOrThrow()) + .build(); + } + + @Override + protected SqlExecutor onRemoteDatabase() + { + return new SqlExecutor() + { + @Override + public boolean supportsMultiRowInsert() + { + return false; + } + + @Override + public void execute(String sql) + { + oracleServer.execute(sql); + } + }; + } +}