-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support cast projection pushdown in oracle
- Loading branch information
Showing
8 changed files
with
949 additions
and
2 deletions.
There are no files selected for viewing
95 changes: 95 additions & 0 deletions
95
...in/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/AbstractRewriteCast.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.jdbc.expression; | ||
|
||
import io.trino.matching.Capture; | ||
import io.trino.matching.Captures; | ||
import io.trino.matching.Pattern; | ||
import io.trino.plugin.base.projection.ProjectFunctionRule; | ||
import io.trino.plugin.jdbc.JdbcColumnHandle; | ||
import io.trino.plugin.jdbc.JdbcExpression; | ||
import io.trino.plugin.jdbc.JdbcTypeHandle; | ||
import io.trino.spi.connector.ConnectorSession; | ||
import io.trino.spi.connector.ConnectorTableHandle; | ||
import io.trino.spi.expression.Call; | ||
import io.trino.spi.expression.ConnectorExpression; | ||
import io.trino.spi.expression.Variable; | ||
import io.trino.spi.type.Type; | ||
|
||
import java.util.Optional; | ||
import java.util.function.BiFunction; | ||
|
||
import static io.trino.matching.Capture.newCapture; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable; | ||
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; | ||
import static java.util.Objects.requireNonNull; | ||
|
||
public abstract class AbstractRewriteCast | ||
implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression> | ||
{ | ||
private static final Capture<Variable> VALUE = newCapture(); | ||
|
||
private final BiFunction<ConnectorSession, Type, String> toTargetType; | ||
|
||
protected abstract Optional<JdbcTypeHandle> toJdbcTypeHandle(JdbcTypeHandle sourceType, Type targetType); | ||
|
||
protected String buildCast(@SuppressWarnings("unused") Type sourceType, @SuppressWarnings("unused") Type targetType, String expression, String castType) | ||
{ | ||
return "CAST(%s AS %s)".formatted(expression, castType); | ||
} | ||
|
||
public AbstractRewriteCast(BiFunction<ConnectorSession, Type, String> toTargetType) | ||
{ | ||
this.toTargetType = requireNonNull(toTargetType, "toTargetType is null"); | ||
} | ||
|
||
@Override | ||
public Pattern<Call> getPattern() | ||
{ | ||
return call() | ||
.with(functionName().equalTo(CAST_FUNCTION_NAME)) | ||
.with(argumentCount().equalTo(1)) | ||
.with(argument(0).matching(variable().capturedAs(VALUE))); | ||
} | ||
|
||
@Override | ||
public Optional<JdbcExpression> rewrite(ConnectorTableHandle handle, ConnectorExpression castExpression, Captures captures, RewriteContext<ParameterizedExpression> context) | ||
{ | ||
Variable variable = captures.get(VALUE); | ||
JdbcTypeHandle sourceTypeJdbcHandle = ((JdbcColumnHandle) context.getAssignment(variable.getName())).getJdbcTypeHandle(); | ||
Type targetType = castExpression.getType(); | ||
Optional<JdbcTypeHandle> targetJdbcTypeHandle = toJdbcTypeHandle(sourceTypeJdbcHandle, targetType); | ||
|
||
if (targetJdbcTypeHandle.isEmpty()) { | ||
return Optional.empty(); | ||
} | ||
|
||
Optional<ParameterizedExpression> value = context.rewriteExpression(variable); | ||
if (value.isEmpty()) { | ||
return Optional.empty(); | ||
} | ||
|
||
Type sourceType = variable.getType(); | ||
String castType = toTargetType.apply(context.getSession(), targetType); | ||
|
||
return Optional.of(new JdbcExpression( | ||
buildCast(sourceType, targetType, value.get().expression(), castType), | ||
value.get().parameters(), | ||
targetJdbcTypeHandle.get())); | ||
} | ||
} |
94 changes: 94 additions & 0 deletions
94
plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcCastPushdownTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.jdbc; | ||
|
||
import io.trino.sql.planner.plan.ProjectNode; | ||
import io.trino.testing.AbstractTestQueryFramework; | ||
import io.trino.testing.sql.SqlExecutor; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.List; | ||
import java.util.Optional; | ||
|
||
import static java.util.Objects.requireNonNull; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.assertj.core.api.Assertions.assertThatThrownBy; | ||
|
||
public abstract class BaseJdbcCastPushdownTest | ||
extends AbstractTestQueryFramework | ||
{ | ||
protected abstract String leftTable(); | ||
|
||
protected abstract String rightTable(); | ||
|
||
protected abstract SqlExecutor onRemoteDatabase(); | ||
|
||
protected abstract List<CastTestCase> supportedCastTypePushdown(); | ||
|
||
protected abstract List<CastTestCase> unsupportedCastTypePushdown(); | ||
|
||
protected abstract List<CastTestCase> failCast(); | ||
|
||
@Test | ||
public void testProjectionPushdownWithCast() | ||
{ | ||
for (CastTestCase testCase : supportedCastTypePushdown()) { | ||
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable()))) | ||
.isFullyPushedDown(); | ||
} | ||
|
||
for (CastTestCase testCase : unsupportedCastTypePushdown()) { | ||
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable()))) | ||
.isNotFullyPushedDown(ProjectNode.class); | ||
} | ||
} | ||
|
||
@Test | ||
public void testJoinPushdownWithCast() | ||
{ | ||
for (CastTestCase testCase : supportedCastTypePushdown()) { | ||
assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) | ||
.isFullyPushedDown(); | ||
} | ||
|
||
for (CastTestCase testCase : unsupportedCastTypePushdown()) { | ||
assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow()))) | ||
.joinIsNotFullyPushedDown(); | ||
} | ||
} | ||
|
||
@Test | ||
public void testCastFails() | ||
{ | ||
for (CastTestCase testCase : failCast()) { | ||
assertThatThrownBy(() -> getQueryRunner().execute("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable()))) | ||
.hasMessageMatching("(.*)Cannot cast (.*) to (.*)"); | ||
} | ||
} | ||
|
||
public record CastTestCase(String sourceColumn, String castType, Optional<String> targetColumn) | ||
{ | ||
public CastTestCase(String sourceColumn, String castType) | ||
{ | ||
this(sourceColumn, castType, Optional.empty()); | ||
} | ||
|
||
public CastTestCase | ||
{ | ||
requireNonNull(sourceColumn, "sourceColumn is null"); | ||
requireNonNull(castType, "castType is null"); | ||
requireNonNull(targetColumn, "targetColumn is null"); | ||
} | ||
} | ||
} |
103 changes: 103 additions & 0 deletions
103
plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/CastDataTypeTestTable.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.jdbc; | ||
|
||
import io.trino.testing.sql.SqlExecutor; | ||
import io.trino.testing.sql.TemporaryRelation; | ||
import io.trino.testing.sql.TestTable; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.stream.IntStream; | ||
|
||
import static com.google.common.base.Preconditions.checkArgument; | ||
import static com.google.common.base.Preconditions.checkState; | ||
import static com.google.common.collect.ImmutableList.toImmutableList; | ||
import static io.airlift.testing.Closeables.closeAll; | ||
import static java.util.Objects.requireNonNull; | ||
import static java.util.stream.Collectors.joining; | ||
|
||
public final class CastDataTypeTestTable | ||
implements TemporaryRelation | ||
{ | ||
private final List<TestCaseInput> testCaseRows = new ArrayList<>(); | ||
private final int rowCount; | ||
|
||
private TestTable testTable; | ||
|
||
private CastDataTypeTestTable(int rowCount) | ||
{ | ||
this.rowCount = rowCount; | ||
} | ||
|
||
public static CastDataTypeTestTable create(int rowCount) | ||
{ | ||
return new CastDataTypeTestTable(rowCount); | ||
} | ||
|
||
public CastDataTypeTestTable addColumn(String columnName, String columnType, List<Object> inputValues) | ||
{ | ||
checkArgument(rowCount == inputValues.size(), "Expected input size: %s, but found: %s", rowCount, inputValues.size()); | ||
testCaseRows.add(new TestCaseInput(columnName, columnType, inputValues)); | ||
return this; | ||
} | ||
|
||
public CastDataTypeTestTable execute(SqlExecutor sqlExecutor, String tableNamePrefix) | ||
{ | ||
checkState(!testCaseRows.isEmpty(), "No test case rows"); | ||
List<String> columnsWithType = new ArrayList<>(); | ||
for (TestCaseInput input : testCaseRows) { | ||
columnsWithType.add("%s %s".formatted(input.columnName(), input.columnType())); | ||
} | ||
String tableDefinition = columnsWithType.stream() | ||
.collect(joining(", ", "(", ")")); | ||
|
||
List<String> rowsToInsert = IntStream.range(0, rowCount) | ||
.mapToObj(rowId -> testCaseRows.stream() | ||
.map(TestCaseInput::inputValues) | ||
.map(rows -> rows.get(rowId)) | ||
.map(String::valueOf) | ||
.collect(joining(","))) | ||
.collect(toImmutableList()); | ||
testTable = new TestTable(sqlExecutor, tableNamePrefix, tableDefinition, rowsToInsert); | ||
return this; | ||
} | ||
|
||
@Override | ||
public String getName() | ||
{ | ||
return testTable.getName(); | ||
} | ||
|
||
@Override | ||
public void close() | ||
{ | ||
try { | ||
closeAll(testTable); | ||
} | ||
catch (Exception e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
public record TestCaseInput(String columnName, String columnType, List<Object> inputValues) | ||
{ | ||
public TestCaseInput | ||
{ | ||
requireNonNull(columnName, "columnName is null"); | ||
requireNonNull(columnType, "columnType is null"); | ||
inputValues = new ArrayList<>(requireNonNull(inputValues, "inputValues is null")); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.