Skip to content

Commit

Permalink
Support cast projection pushdown in oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
krvikash committed Jul 25, 2024
1 parent 02b5838 commit b44eb71
Show file tree
Hide file tree
Showing 7 changed files with 759 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.type.Type;

import java.util.Optional;
import java.util.function.BiFunction;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static java.util.Objects.requireNonNull;

public abstract class AbstractRewriteCast
implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression>
{
private static final Capture<ConnectorExpression> VALUE = newCapture();
private final BiFunction<ConnectorSession, Type, String> toTargetType;

protected abstract Optional<JdbcTypeHandle> toJdbcTypeHandle(Type sourceType, Type targetType);
protected abstract String buildCast(String expression, String castType);

public AbstractRewriteCast(BiFunction<ConnectorSession, Type, String> toTargetType)
{
this.toTargetType = requireNonNull(toTargetType, "toTargetType is null");
}

@Override
public Pattern<Call> getPattern()
{
return call()
.with(functionName().equalTo(CAST_FUNCTION_NAME))
.with(argumentCount().equalTo(1))
.with(argument(0).matching(expression().capturedAs(VALUE)));
}

@Override
public Optional<JdbcExpression> rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext<ParameterizedExpression> context)
{
Call call = (Call) projectionExpression;
Type targetType = call.getType();
ConnectorExpression capturedValue = captures.get(VALUE);
Type sourceType = capturedValue.getType();

Optional<ParameterizedExpression> value = context.rewriteExpression(capturedValue);
if (value.isEmpty()) {
// if argument is a call chain that can't be rewritten, then we can't push it down
return Optional.empty();
}
Optional<JdbcTypeHandle> targetTypeHandle = toJdbcTypeHandle(sourceType, targetType);
if (targetTypeHandle.isEmpty()) {
return Optional.empty();
}

String castType = toTargetType.apply(context.getSession(), targetType);

return Optional.of(new JdbcExpression(
buildCast(value.get().expression(), castType),
value.get().parameters(),
targetTypeHandle.get()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import io.trino.Session;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.query.QueryAssertions;
import io.trino.testing.QueryRunner;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import java.util.List;

import static io.trino.plugin.jdbc.CastDataTypeTest.CastTestCase;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;

@TestInstance(PER_CLASS)
public abstract class BaseCastPushdownTest
{
protected final Session session;
protected final QueryRunner queryRunner;
protected final QueryAssertions assertions;
protected CastDataTypeTest table1;
protected CastDataTypeTest table2;

public BaseCastPushdownTest(Session session, QueryRunner queryRunner)
{
this.session = requireNonNull(session, "session is null");
this.queryRunner = requireNonNull(queryRunner, "queryRunner is null");
this.assertions = new QueryAssertions(queryRunner);
}

protected abstract void setupTable();
protected abstract List<CastTestCase> supportedCastTypePushdown();
protected abstract List<CastTestCase> unsupportedCastTypePushdown();
protected abstract List<CastTestCase> failCast();

@BeforeAll
public void setup()
{
setupTable();
}

@AfterAll
public void cleanup()
{
dropTable(table1);
dropTable(table2);
}

private void dropTable(CastDataTypeTest table)
{
if (table2 != null) {
queryRunner.execute("DROP TABLE IF EXISTS " + table.tableName());
}
}

@Test
public void testCastProjectionPushdown()
{
for (CastDataTypeTest.CastTestCase testCase : supportedCastTypePushdown()) {
assertThat(assertions.query(session, "SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), table1.tableName())))
.isFullyPushedDown();
}

for (CastTestCase testCase : unsupportedCastTypePushdown()) {
assertThat(assertions.query(session, "SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), table1.tableName())))
.isNotFullyPushedDown(ProjectNode.class);
}
}

@Test
public void testCastProjectionFails()
{
for (CastTestCase testCase : failCast()) {
assertThatThrownBy(() -> queryRunner.execute(session, "SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), table1.tableName())))
.hasMessageMatching("(.*)Cannot cast (.*) to (.*)");
}
}

@Test
public void testJoinPushdownWithCast()
{
// combination of supported types
for (CastTestCase testCase : supportedCastTypePushdown()) {
assertJoinFullyPushedDown(session, table1.tableName(), table2.tableName(), "JOIN", "CAST(l.%s AS %s) = r.%s".formatted(testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()));
}

// combination of unsupported types
for (CastTestCase testCase : unsupportedCastTypePushdown()) {
assertJoinNotFullyPushedDown(session, table1.tableName(), table2.tableName(), "CAST(l.%s AS %s) = r.%s".formatted(testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()));
}
}

protected void assertJoinFullyPushedDown(Session session, String leftTable, String rightTable, String joinType, String joinCondition)
{
assertJoin(session, leftTable, rightTable, joinType, joinCondition)
.isFullyPushedDown();
}

protected void assertJoinNotFullyPushedDown(Session session, String leftTable, String rightTable, String joinCondition)
{
assertJoin(session, leftTable, rightTable, "JOIN", joinCondition)
.joinIsNotFullyPushedDown();
}

protected QueryAssertions.QueryAssert assertJoin(Session session, String leftTable, String rightTable, String joinType, String joinCondition)
{
System.out.println("SELECT l.id FROM %s l %s %s r ON %s".formatted(leftTable, joinType, rightTable, joinCondition));
return assertThat(assertions.query(session, "SELECT l.id FROM %s l %s %s r ON %s".formatted(leftTable, joinType, rightTable, joinCondition)));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import io.trino.testing.QueryRunner;
import io.trino.testing.sql.TestTable;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public class CastDataTypeTest
{
private final List<TestCaseInput> testCaseInputs = new ArrayList<>();

private OptionalInt inputSize = OptionalInt.empty();
private TestTable testTable;

private CastDataTypeTest() {}

public static CastDataTypeTest create()
{
return new CastDataTypeTest();
}

public CastDataTypeTest addColumn(String columnName, String columnType, List<Object> inputValues)
{
if (inputSize.isEmpty()) {
inputSize = OptionalInt.of(inputValues.size());
}
checkArgument(inputSize.getAsInt() == inputValues.size(), "Expected input size: %s, but found: %s", inputSize.getAsInt(), inputValues.size());
testCaseInputs.add(new TestCaseInput(columnName, columnType, inputValues));
return this;
}

public CastDataTypeTest execute(QueryRunner queryRunner, String tableNamePrefix)
{
checkState(!testCaseInputs.isEmpty(), "No test case inputs");
List<String> columnsWithType = new ArrayList<>();
for (TestCaseInput input : testCaseInputs) {
columnsWithType.add(input.columnName() + " " + input.columnType());
}
String tableDefinition = "(" + String.join(", ", columnsWithType) + ")";

if (inputSize.isPresent()) {
List<String> rowToInsert = new ArrayList<>();
for (int i = 0; i < inputSize.getAsInt(); i++) {
List<String> row = new ArrayList<>();
for (TestCaseInput input : testCaseInputs) {
row.add(String.valueOf(input.inputValues().get(i)));
}
rowToInsert.add(String.join(", ", row));
}
testTable = new TestTable(queryRunner::execute, tableNamePrefix, tableDefinition, rowToInsert);
}
else {
testTable = new TestTable(queryRunner::execute, tableNamePrefix, tableDefinition);
}
return this;
}

public String tableName()
{
return testTable.getName();
}

public record TestCaseInput(String columnName, String columnType, List<Object> inputValues)
{
public TestCaseInput
{
requireNonNull(columnName, "columnName is null");
requireNonNull(columnType, "columnType is null");
requireNonNull(inputValues, "inputValues is null");
}
}

public record CastTestCase(String sourceColumn, String castType, Optional<String> targetColumn)
{
public CastTestCase
{
requireNonNull(sourceColumn, "sourceColumn is null");
requireNonNull(castType, "castType is null");
requireNonNull(targetColumn, "targetColumn is null");
}

public CastTestCase(String sourceColumn, String castType)
{
this(sourceColumn, castType, Optional.empty());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.base.projection.ProjectFunctionRewriter;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.BooleanWriteFunction;
Expand Down Expand Up @@ -165,10 +167,10 @@ public class OracleClient
private static final int MAX_BYTES_PER_CHAR = 4;

private static final int ORACLE_VARCHAR2_MAX_BYTES = 4000;
private static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR;
public static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR;

private static final int ORACLE_CHAR_MAX_BYTES = 2000;
private static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR;
public static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR;

private static final int PRECISION_OF_UNSPECIFIED_NUMBER = 127;

Expand Down Expand Up @@ -215,6 +217,7 @@ public class OracleClient

private final boolean synonymsEnabled;
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
private final ProjectFunctionRewriter<JdbcExpression, ParameterizedExpression> projectFunctionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;

@Inject
Expand Down Expand Up @@ -242,6 +245,12 @@ public OracleClient(
.add(new RewriteStringComparison())
.build();

this.projectFunctionRewriter = new ProjectFunctionRewriter<>(
this.connectorExpressionRewriter,
ImmutableSet.<ProjectFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
.add(new RewriteCast(this::toTargetType))
.build());

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(TRINO_BIGINT_TYPE, Optional.of("NUMBER"), Optional.of(0), Optional.of(0), Optional.empty(), Optional.empty());
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
connectorExpressionRewriter,
Expand Down Expand Up @@ -553,6 +562,12 @@ public Optional<ParameterizedExpression> convertPredicate(ConnectorSession sessi
return connectorExpressionRewriter.rewrite(session, expression, assignments);
}

@Override
public Optional<JdbcExpression> convertProjection(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return projectFunctionRewriter.rewrite(session, expression, assignments);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(OracleTypes.NUMBER, Optional.of("NUMBER"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down Expand Up @@ -793,6 +808,11 @@ private SliceWriteFunction oracleCharWriteFunction()
return SliceWriteFunction.of(Types.NCHAR, (statement, index, value) -> statement.unwrap(OraclePreparedStatement.class).setFixedCHAR(index, value.toStringUtf8()));
}

private String toTargetType(ConnectorSession session, Type type)
{
return toWriteMapping(session, type).getDataType();
}

@Override
public WriteMapping toWriteMapping(ConnectorSession session, Type type)
{
Expand Down
Loading

0 comments on commit b44eb71

Please sign in to comment.