Skip to content

Commit

Permalink
Support cast projection pushdown in oracle
Browse files Browse the repository at this point in the history
Co-authored-by: Sasha Sheikin <[email protected]>
  • Loading branch information
krvikash and ssheikin committed Jul 18, 2024
1 parent 02b5838 commit 1acef9d
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.base.projection.ProjectFunctionRewriter;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.BooleanWriteFunction;
Expand Down Expand Up @@ -165,10 +167,10 @@ public class OracleClient
private static final int MAX_BYTES_PER_CHAR = 4;

private static final int ORACLE_VARCHAR2_MAX_BYTES = 4000;
private static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR;
public static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR;

private static final int ORACLE_CHAR_MAX_BYTES = 2000;
private static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR;
public static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR;

private static final int PRECISION_OF_UNSPECIFIED_NUMBER = 127;

Expand Down Expand Up @@ -215,6 +217,7 @@ public class OracleClient

private final boolean synonymsEnabled;
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
private final ProjectFunctionRewriter<JdbcExpression, ParameterizedExpression> projectFunctionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;

@Inject
Expand Down Expand Up @@ -242,6 +245,12 @@ public OracleClient(
.add(new RewriteStringComparison())
.build();

this.projectFunctionRewriter = new ProjectFunctionRewriter<>(
this.connectorExpressionRewriter,
ImmutableSet.<ProjectFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
.add(new RewriteCast(this::toWriteMapping))
.build());

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(TRINO_BIGINT_TYPE, Optional.of("NUMBER"), Optional.of(0), Optional.of(0), Optional.empty(), Optional.empty());
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
connectorExpressionRewriter,
Expand Down Expand Up @@ -553,6 +562,12 @@ public Optional<ParameterizedExpression> convertPredicate(ConnectorSession sessi
return connectorExpressionRewriter.rewrite(session, expression, assignments);
}

/**
 * Converts a projection expression into a {@link JdbcExpression} by delegating to
 * {@code projectFunctionRewriter}, which the constructor builds on top of
 * {@code connectorExpressionRewriter} with the {@code RewriteCast} rule.
 *
 * @param session the session forwarded to the rewriter
 * @param expression the projection expression to convert
 * @param assignments column handles keyed by name, forwarded to the rewriter unchanged
 * @return whatever the rewriter returns (presumably empty when no rule applies; verify against ProjectFunctionRewriter)
 */
@Override
public Optional<JdbcExpression> convertProjection(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return projectFunctionRewriter.rewrite(session, expression, assignments);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(OracleTypes.NUMBER, Optional.of("NUMBER"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.oracle;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.WriteMapping;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.type.CharType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import oracle.jdbc.OracleTypes;

import java.util.Optional;
import java.util.function.BiFunction;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.oracle.OracleClient.ORACLE_CHAR_MAX_CHARS;
import static io.trino.plugin.oracle.OracleClient.ORACLE_VARCHAR2_MAX_CHARS;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * Projection rule that turns a single-argument {@code CAST} call into an Oracle
 * {@code CAST(<expression> AS <type>)} expression.
 * <p>
 * A cast is rewritten only when its target type is a {@code char} no longer than
 * {@code ORACLE_CHAR_MAX_CHARS}, or a bounded {@code varchar} no longer than
 * {@code ORACLE_VARCHAR2_MAX_CHARS}. For every other target type, and whenever the
 * cast argument itself cannot be rewritten, the rule yields {@link Optional#empty()}.
 */
public class RewriteCast
        implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression>
{
    private static final Capture<ConnectorExpression> VALUE = newCapture();

    // Supplies the write mapping whose data type string is used as the cast's target type
    private final BiFunction<ConnectorSession, Type, WriteMapping> toWriteMapping;

    public RewriteCast(BiFunction<ConnectorSession, Type, WriteMapping> toWriteMapping)
    {
        this.toWriteMapping = requireNonNull(toWriteMapping, "toWriteMapping is null");
    }

    /**
     * Matches calls named {@code CAST_FUNCTION_NAME} with exactly one argument and
     * captures that argument as {@code VALUE}.
     */
    @Override
    public Pattern<Call> getPattern()
    {
        Pattern<Call> castCall = call().with(functionName().equalTo(CAST_FUNCTION_NAME));
        return castCall
                .with(argumentCount().equalTo(1))
                .with(argument(0).matching(expression().capturedAs(VALUE)));
    }

    /**
     * Builds the {@code CAST(... AS ...)} expression for a matched call.
     *
     * @return the rewritten expression, or empty when either the argument or the target type is not supported
     */
    @Override
    public Optional<JdbcExpression> rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext<ParameterizedExpression> context)
    {
        Type castTargetType = ((Call) projectionExpression).getType();

        Optional<ParameterizedExpression> rewrittenArgument = context.rewriteExpression(captures.get(VALUE));
        if (rewrittenArgument.isEmpty()) {
            // the argument (possibly a call chain) has no remote equivalent, so the cast cannot be pushed down either
            return Optional.empty();
        }

        Optional<JdbcTypeHandle> resultTypeHandle = toJdbcTypeHandle(castTargetType);
        if (resultTypeHandle.isEmpty()) {
            return Optional.empty();
        }

        // The write mapping is consulted only after the target type is known to be supported
        String oracleTypeName = toWriteMapping.apply(context.getSession(), castTargetType).getDataType();
        ParameterizedExpression argument = rewrittenArgument.get();
        String castSql = format("CAST(%s AS %s)", argument.expression(), oracleTypeName);
        return Optional.of(new JdbcExpression(castSql, argument.parameters(), resultTypeHandle.get()));
    }

    /**
     * Maps the cast target type to the type handle of the resulting expression.
     *
     * @return a handle for supported {@code char}/{@code varchar} targets, otherwise empty
     */
    private static Optional<JdbcTypeHandle> toJdbcTypeHandle(Type type)
    {
        if (type instanceof CharType charType) {
            int length = charType.getLength();
            if (length <= ORACLE_CHAR_MAX_CHARS) {
                return Optional.of(characterTypeHandle(OracleTypes.CHAR, charType.getBaseName(), Optional.of(length)));
            }
            return Optional.empty();
        }
        if (type instanceof VarcharType varcharType) {
            // unbounded varchar is rejected before the bounded length is read
            boolean withinLimit = !varcharType.isUnbounded() && varcharType.getBoundedLength() <= ORACLE_VARCHAR2_MAX_CHARS;
            if (withinLimit) {
                return Optional.of(characterTypeHandle(OracleTypes.VARCHAR, "VARCHAR2", varcharType.getLength()));
            }
            return Optional.empty();
        }
        return Optional.empty();
    }

    // Creates a handle carrying only the JDBC type, its name and the column size
    private static JdbcTypeHandle characterTypeHandle(int jdbcType, String jdbcTypeName, Optional<Integer> columnSize)
    {
        return new JdbcTypeHandle(jdbcType, Optional.of(jdbcTypeName), columnSize, Optional.empty(), Optional.empty(), Optional.empty());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.oracle;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
Expand All @@ -25,17 +26,20 @@
import io.trino.testing.sql.TestView;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.plugin.oracle.TestingOracleServer.TEST_USER;
import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.testing.MaterializedResult.resultBuilder;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

Expand Down Expand Up @@ -191,6 +195,119 @@ public void testTimestampOutOfPrecisionRounded()
assertUpdate("DROP TABLE " + tableName);
}

/**
 * Checks, for a list of (column type, cast type) pairs, whether
 * {@code SELECT CAST(a AS <cast type>)} is fully pushed down or leaves a
 * {@code ProjectNode} in the plan. Each case runs against its own temporary table.
 */
@Test
public void testCastProjectionPushdown()
{
    List<CastTestCase> pushedDownCases = ImmutableList.of(
            new CastTestCase("char(10)", "char(20)"),
            new CastTestCase("char(10)", "varchar(20)"),
            new CastTestCase("char(501)", "varchar(20)"),
            // no cast is used during execution, as both types are converted to nclob
            new CastTestCase("char(501)", "varchar"),
            new CastTestCase("varchar(10)", "varchar(20)"),
            new CastTestCase("varchar(1001)", "varchar(20)"),
            // no cast is used during execution, as both types are converted to nclob
            new CastTestCase("varchar(1001)", "varchar"),
            new CastTestCase("varchar", "varchar(20)"));

    for (CastTestCase pushedDownCase : pushedDownCases) {
        String columnDefinition = format("(a %s)", pushedDownCase.fromType());
        try (TestTable table = new TestTable(getQueryRunner()::execute, "test_cast_", columnDefinition)) {
            String sql = format("SELECT CAST(a AS %s) FROM %s", pushedDownCase.castType(), table.getName());
            assertThat(query(sql)).isFullyPushedDown();
        }
    }

    List<CastTestCase> notPushedDownCases = ImmutableList.of(
            new CastTestCase("integer", "tinyint"),
            new CastTestCase("integer", "smallint"),
            new CastTestCase("integer", "bigint"),
            new CastTestCase("integer", "real"),
            new CastTestCase("integer", "double"),
            new CastTestCase("integer", "decimal(15)"),
            new CastTestCase("integer", "decimal(10, 2)"),
            new CastTestCase("integer", "decimal(30, 2)"),
            new CastTestCase("smallint", "integer"),
            new CastTestCase("char(10)", "char(501)"),
            new CastTestCase("char(510)", "char(520)"),
            new CastTestCase("char(510)", "varchar(1001)"),
            new CastTestCase("char(10)", "varchar(1001)"),
            new CastTestCase("char(10)", "varchar"),
            new CastTestCase("varchar(10)", "varchar(1001)"),
            new CastTestCase("varchar(10)", "varchar"),
            new CastTestCase("varchar(1001)", "varchar(1010)"),
            new CastTestCase("varchar(10)", "varbinary"),
            new CastTestCase("timestamp", "date"),
            new CastTestCase("date", "timestamp"),
            new CastTestCase("date", "timestamp with time zone"));

    for (CastTestCase notPushedDownCase : notPushedDownCases) {
        String columnDefinition = format("(a %s)", notPushedDownCase.fromType());
        try (TestTable table = new TestTable(getQueryRunner()::execute, "test_cast_", columnDefinition)) {
            String sql = format("SELECT CAST(a AS %s) FROM %s", notPushedDownCase.castType(), table.getName());
            assertThat(query(sql)).isNotFullyPushedDown(ProjectNode.class);
        }
    }
}

/**
 * Verifies that LEFT, RIGHT and INNER joins comparing a {@code varchar(50)} column with a
 * {@code varchar(100)} column, without an explicit cast in the condition, are fully pushed down
 * and return the expected ids.
 * <p>
 * Both tables are dropped in {@code finally} blocks so that a failing insert or assertion
 * does not leave them behind in the test database.
 */
@Test
public void testImplicitCastJoinPushdown()
{
    String leftTable = "left_table_" + randomNameSuffix();
    String rightTable = "right_table_" + randomNameSuffix();

    computeActual("CREATE TABLE %s (id int, varchar_50 varchar(50))".formatted(leftTable));
    try {
        assertUpdate("insert into %s values (1, 'India')".formatted(leftTable), 1);
        assertUpdate("insert into %s values (2, 'Poland')".formatted(leftTable), 1);

        computeActual("CREATE TABLE %s (varchar_100 varchar(100), capital varchar)".formatted(rightTable));
        try {
            assertUpdate("insert into %s values ('India', 'New Delhi')".formatted(rightTable), 1);
            assertUpdate("insert into %s values ('France', 'Paris')".formatted(rightTable), 1);

            Session session = joinPushdownEnabled(getSession());

            // Arrays.asList is used where the expected output contains null, which List.of rejects
            assertJoinFullyPushedDown(session, leftTable, rightTable, "LEFT JOIN", "l.varchar_50 = r.varchar_100", Arrays.asList(1, 2));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "RIGHT JOIN", "l.varchar_50 = r.varchar_100", Arrays.asList(1, null));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "INNER JOIN", "l.varchar_50 = r.varchar_100", List.of(1));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "INNER JOIN", "r.varchar_100 = l.varchar_50", List.of(1));
        }
        finally {
            assertUpdate("DROP TABLE " + rightTable);
        }
    }
    finally {
        assertUpdate("DROP TABLE " + leftTable);
    }
}

/**
 * Verifies that LEFT, RIGHT and INNER joins whose condition contains an explicit
 * {@code CAST(... AS VARCHAR(n))} on one or both sides are fully pushed down and
 * return the expected ids.
 * <p>
 * Both tables are dropped in {@code finally} blocks so that a failing insert or assertion
 * does not leave them behind in the test database.
 */
@Test
public void testExplicitCastJoinPushdown()
{
    String leftTable = "left_table_" + randomNameSuffix();
    String rightTable = "right_table_" + randomNameSuffix();

    computeActual("CREATE TABLE %s (id int, varchar_50 varchar(50))".formatted(leftTable));
    try {
        assertUpdate("insert into %s values (1, 'India')".formatted(leftTable), 1);
        assertUpdate("insert into %s values (2, 'Poland')".formatted(leftTable), 1);

        computeActual("CREATE TABLE %s (varchar_100 varchar(100), capital varchar)".formatted(rightTable));
        try {
            assertUpdate("insert into %s values ('India', 'New Delhi')".formatted(rightTable), 1);
            assertUpdate("insert into %s values ('France', 'Paris')".formatted(rightTable), 1);

            Session session = joinPushdownEnabled(getSession());

            // Each join type is checked with the cast on the left side, on the right side, and on both sides
            assertJoinFullyPushedDown(session, leftTable, rightTable, "LEFT JOIN", "CAST(l.varchar_50 AS VARCHAR(100)) = r.varchar_100", Arrays.asList(1, 2));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "LEFT JOIN", "l.varchar_50 = CAST(r.varchar_100 AS VARCHAR(50))", Arrays.asList(1, 2));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "LEFT JOIN", "CAST(l.varchar_50 AS VARCHAR(200)) = CAST(r.varchar_100 AS VARCHAR(200))", Arrays.asList(1, 2));

            assertJoinFullyPushedDown(session, leftTable, rightTable, "RIGHT JOIN", "CAST(l.varchar_50 AS VARCHAR(100)) = r.varchar_100", Arrays.asList(1, null));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "RIGHT JOIN", "l.varchar_50 = CAST(r.varchar_100 AS VARCHAR(50))", Arrays.asList(1, null));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "RIGHT JOIN", "CAST(l.varchar_50 AS VARCHAR(200)) = CAST(r.varchar_100 AS VARCHAR(200))", Arrays.asList(1, null));

            assertJoinFullyPushedDown(session, leftTable, rightTable, "INNER JOIN", "CAST(l.varchar_50 AS VARCHAR(100)) = r.varchar_100", List.of(1));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "INNER JOIN", "l.varchar_50 = CAST(r.varchar_100 AS VARCHAR(50))", List.of(1));
            assertJoinFullyPushedDown(session, leftTable, rightTable, "INNER JOIN", "CAST(l.varchar_50 AS VARCHAR(200)) = CAST(r.varchar_100 AS VARCHAR(200))", List.of(1));
        }
        finally {
            assertUpdate("DROP TABLE " + rightTable);
        }
    }
    finally {
        assertUpdate("DROP TABLE " + leftTable);
    }
}

/**
 * Runs {@code SELECT id FROM <leftTable> l <joinType> <rightTable> r ON <joinCondition>} and
 * asserts that the result matches the expected ids and that the query is fully pushed down.
 *
 * @param expectedOutput expected {@code id} values; each one, including {@code null}, is rendered
 *         as a {@code CAST(<value> AS DECIMAL(10,0))} row of a {@code VALUES} clause
 */
private void assertJoinFullyPushedDown(Session session, String leftTable, String rightTable, String joinType, String joinCondition, List<Integer> expectedOutput)
{
    List<String> expectedRows = expectedOutput.stream()
            .map(id -> "(CAST(%s AS DECIMAL(10,0)))".formatted(id))
            .collect(toImmutableList());
    String expectedValues = "VALUES " + Joiner.on(",").join(expectedRows);

    String joinQuery = "SELECT id FROM %s l %s %s r ON %s".formatted(leftTable, joinType, rightTable, joinCondition);
    assertThat(query(session, joinQuery))
            .matches(expectedValues)
            .isFullyPushedDown();
}

@Test
@Override
public void testCharVarcharComparison()
Expand Down Expand Up @@ -500,4 +617,13 @@ protected String getUser()
{
return TEST_USER;
}

/**
 * A cast scenario for the pushdown tests: {@code fromType} is the declared type of the
 * table column and {@code castType} is the target type of the CAST applied to it.
 */
private record CastTestCase(String fromType, String castType)
{
private CastTestCase
{
// both components are mandatory; fromType is checked first
requireNonNull(fromType, "fromType is null");
requireNonNull(castType, "castType is null");
}
}
}

0 comments on commit 1acef9d

Please sign in to comment.