Skip to content

Commit

Permalink
Support cast projection pushdown in oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
krvikash committed Sep 5, 2024
1 parent 8455ce6 commit 93b238f
Show file tree
Hide file tree
Showing 8 changed files with 949 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.Type;

import java.util.Optional;
import java.util.function.BiFunction;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static java.util.Objects.requireNonNull;

public abstract class AbstractRewriteCast
implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression>
{
private static final Capture<Variable> VALUE = newCapture();

private final BiFunction<ConnectorSession, Type, String> toTargetType;

protected abstract Optional<JdbcTypeHandle> toJdbcTypeHandle(JdbcTypeHandle sourceType, Type targetType);

protected String buildCast(@SuppressWarnings("unused") Type sourceType, @SuppressWarnings("unused") Type targetType, String expression, String castType)
{
return "CAST(%s AS %s)".formatted(expression, castType);
}

public AbstractRewriteCast(BiFunction<ConnectorSession, Type, String> toTargetType)
{
this.toTargetType = requireNonNull(toTargetType, "toTargetType is null");
}

@Override
public Pattern<Call> getPattern()
{
return call()
.with(functionName().equalTo(CAST_FUNCTION_NAME))
.with(argumentCount().equalTo(1))
.with(argument(0).matching(variable().capturedAs(VALUE)));
}

@Override
public Optional<JdbcExpression> rewrite(ConnectorTableHandle handle, ConnectorExpression castExpression, Captures captures, RewriteContext<ParameterizedExpression> context)
{
Variable variable = captures.get(VALUE);
JdbcTypeHandle sourceTypeJdbcHandle = ((JdbcColumnHandle) context.getAssignment(variable.getName())).getJdbcTypeHandle();
Type targetType = castExpression.getType();
Optional<JdbcTypeHandle> targetJdbcTypeHandle = toJdbcTypeHandle(sourceTypeJdbcHandle, targetType);

if (targetJdbcTypeHandle.isEmpty()) {
return Optional.empty();
}

Optional<ParameterizedExpression> value = context.rewriteExpression(variable);
if (value.isEmpty()) {
return Optional.empty();
}

Type sourceType = variable.getType();
String castType = toTargetType.apply(context.getSession(), targetType);

return Optional.of(new JdbcExpression(
buildCast(sourceType, targetType, value.get().expression(), castType),
value.get().parameters(),
targetJdbcTypeHandle.get()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import io.trino.sql.planner.plan.ProjectNode;
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.sql.SqlExecutor;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public abstract class BaseJdbcCastPushdownTest
extends AbstractTestQueryFramework
{
protected abstract String leftTable();

protected abstract String rightTable();

protected abstract SqlExecutor onRemoteDatabase();

protected abstract List<CastTestCase> supportedCastTypePushdown();

protected abstract List<CastTestCase> unsupportedCastTypePushdown();

protected abstract List<CastTestCase> failCast();

@Test
public void testProjectionPushdownWithCast()
{
for (CastTestCase testCase : supportedCastTypePushdown()) {
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable())))
.isFullyPushedDown();
}

for (CastTestCase testCase : unsupportedCastTypePushdown()) {
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable())))
.isNotFullyPushedDown(ProjectNode.class);
}
}

@Test
public void testJoinPushdownWithCast()
{
for (CastTestCase testCase : supportedCastTypePushdown()) {
assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow())))
.isFullyPushedDown();
}

for (CastTestCase testCase : unsupportedCastTypePushdown()) {
assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow())))
.joinIsNotFullyPushedDown();
}
}

@Test
public void testCastFails()
{
for (CastTestCase testCase : failCast()) {
assertThatThrownBy(() -> getQueryRunner().execute("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable())))
.hasMessageMatching("(.*)Cannot cast (.*) to (.*)");
}
}

public record CastTestCase(String sourceColumn, String castType, Optional<String> targetColumn)
{
public CastTestCase(String sourceColumn, String castType)
{
this(sourceColumn, castType, Optional.empty());
}

public CastTestCase
{
requireNonNull(sourceColumn, "sourceColumn is null");
requireNonNull(castType, "castType is null");
requireNonNull(targetColumn, "targetColumn is null");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import io.trino.testing.sql.SqlExecutor;
import io.trino.testing.sql.TemporaryRelation;
import io.trino.testing.sql.TestTable;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.testing.Closeables.closeAll;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;

public final class CastDataTypeTestTable
implements TemporaryRelation
{
private final List<TestCaseInput> testCaseRows = new ArrayList<>();
private final int rowCount;

private TestTable testTable;

private CastDataTypeTestTable(int rowCount)
{
this.rowCount = rowCount;
}

public static CastDataTypeTestTable create(int rowCount)
{
return new CastDataTypeTestTable(rowCount);
}

public CastDataTypeTestTable addColumn(String columnName, String columnType, List<Object> inputValues)
{
checkArgument(rowCount == inputValues.size(), "Expected input size: %s, but found: %s", rowCount, inputValues.size());
testCaseRows.add(new TestCaseInput(columnName, columnType, inputValues));
return this;
}

public CastDataTypeTestTable execute(SqlExecutor sqlExecutor, String tableNamePrefix)
{
checkState(!testCaseRows.isEmpty(), "No test case rows");
List<String> columnsWithType = new ArrayList<>();
for (TestCaseInput input : testCaseRows) {
columnsWithType.add("%s %s".formatted(input.columnName(), input.columnType()));
}
String tableDefinition = columnsWithType.stream()
.collect(joining(", ", "(", ")"));

List<String> rowsToInsert = IntStream.range(0, rowCount)
.mapToObj(rowId -> testCaseRows.stream()
.map(TestCaseInput::inputValues)
.map(rows -> rows.get(rowId))
.map(String::valueOf)
.collect(joining(",")))
.collect(toImmutableList());
testTable = new TestTable(sqlExecutor, tableNamePrefix, tableDefinition, rowsToInsert);
return this;
}

@Override
public String getName()
{
return testTable.getName();
}

@Override
public void close()
{
try {
closeAll(testTable);
}
catch (Exception e) {
throw new RuntimeException(e);
}
}

public record TestCaseInput(String columnName, String columnType, List<Object> inputValues)
{
public TestCaseInput
{
requireNonNull(columnName, "columnName is null");
requireNonNull(columnType, "columnType is null");
inputValues = new ArrayList<>(requireNonNull(inputValues, "inputValues is null"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.base.projection.ProjectFunctionRewriter;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.BooleanWriteFunction;
Expand Down Expand Up @@ -165,10 +167,10 @@ public class OracleClient
private static final int MAX_BYTES_PER_CHAR = 4;

private static final int ORACLE_VARCHAR2_MAX_BYTES = 4000;
private static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR;
public static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR;

private static final int ORACLE_CHAR_MAX_BYTES = 2000;
private static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR;
public static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR;

private static final int PRECISION_OF_UNSPECIFIED_NUMBER = 127;

Expand Down Expand Up @@ -215,6 +217,7 @@ public class OracleClient

private final boolean synonymsEnabled;
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
private final ProjectFunctionRewriter<JdbcExpression, ParameterizedExpression> projectFunctionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;
private final Optional<Integer> fetchSize;

Expand Down Expand Up @@ -243,6 +246,12 @@ public OracleClient(
.add(new RewriteStringComparison())
.build();

this.projectFunctionRewriter = new ProjectFunctionRewriter<>(
this.connectorExpressionRewriter,
ImmutableSet.<ProjectFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
.add(new RewriteCast((session, type) -> toWriteMapping(session, type).getDataType()))
.build());

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(TRINO_BIGINT_TYPE, Optional.of("NUMBER"), Optional.of(0), Optional.of(0), Optional.empty(), Optional.empty());
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
connectorExpressionRewriter,
Expand Down Expand Up @@ -559,6 +568,12 @@ public Optional<ParameterizedExpression> convertPredicate(ConnectorSession sessi
return connectorExpressionRewriter.rewrite(session, expression, assignments);
}

@Override
public Optional<JdbcExpression> convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return projectFunctionRewriter.rewrite(session, handle, expression, assignments);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(OracleTypes.NUMBER, Optional.of("NUMBER"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down
Loading

0 comments on commit 93b238f

Please sign in to comment.