-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support cast projection pushdown in oracle
- Loading branch information
Showing
7 changed files
with
759 additions
and
2 deletions.
There are no files selected for viewing
87 changes: 87 additions & 0 deletions
87
...in/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/AbstractRewriteCast.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.jdbc.expression; | ||
|
||
import io.trino.matching.Capture; | ||
import io.trino.matching.Captures; | ||
import io.trino.matching.Pattern; | ||
import io.trino.plugin.base.projection.ProjectFunctionRule; | ||
import io.trino.plugin.jdbc.JdbcExpression; | ||
import io.trino.plugin.jdbc.JdbcTypeHandle; | ||
import io.trino.spi.connector.ConnectorSession; | ||
import io.trino.spi.expression.Call; | ||
import io.trino.spi.expression.ConnectorExpression; | ||
import io.trino.spi.type.Type; | ||
|
||
import java.util.Optional; | ||
import java.util.function.BiFunction; | ||
|
||
import static io.trino.matching.Capture.newCapture; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; | ||
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; | ||
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; | ||
import static java.util.Objects.requireNonNull; | ||
|
||
public abstract class AbstractRewriteCast | ||
implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression> | ||
{ | ||
private static final Capture<ConnectorExpression> VALUE = newCapture(); | ||
private final BiFunction<ConnectorSession, Type, String> toTargetType; | ||
|
||
protected abstract Optional<JdbcTypeHandle> toJdbcTypeHandle(Type sourceType, Type targetType); | ||
protected abstract String buildCast(String expression, String castType); | ||
|
||
public AbstractRewriteCast(BiFunction<ConnectorSession, Type, String> toTargetType) | ||
{ | ||
this.toTargetType = requireNonNull(toTargetType, "toTargetType is null"); | ||
} | ||
|
||
@Override | ||
public Pattern<Call> getPattern() | ||
{ | ||
return call() | ||
.with(functionName().equalTo(CAST_FUNCTION_NAME)) | ||
.with(argumentCount().equalTo(1)) | ||
.with(argument(0).matching(expression().capturedAs(VALUE))); | ||
} | ||
|
||
@Override | ||
public Optional<JdbcExpression> rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext<ParameterizedExpression> context) | ||
{ | ||
Call call = (Call) projectionExpression; | ||
Type targetType = call.getType(); | ||
ConnectorExpression capturedValue = captures.get(VALUE); | ||
Type sourceType = capturedValue.getType(); | ||
|
||
Optional<ParameterizedExpression> value = context.rewriteExpression(capturedValue); | ||
if (value.isEmpty()) { | ||
// if argument is a call chain that can't be rewritten, then we can't push it down | ||
return Optional.empty(); | ||
} | ||
Optional<JdbcTypeHandle> targetTypeHandle = toJdbcTypeHandle(sourceType, targetType); | ||
if (targetTypeHandle.isEmpty()) { | ||
return Optional.empty(); | ||
} | ||
|
||
String castType = toTargetType.apply(context.getSession(), targetType); | ||
|
||
return Optional.of(new JdbcExpression( | ||
buildCast(value.get().expression(), castType), | ||
value.get().parameters(), | ||
targetTypeHandle.get())); | ||
} | ||
} |
128 changes: 128 additions & 0 deletions
128
plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseCastPushdownTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.jdbc; | ||
|
||
import io.trino.Session; | ||
import io.trino.sql.planner.plan.ProjectNode; | ||
import io.trino.sql.query.QueryAssertions; | ||
import io.trino.testing.QueryRunner; | ||
import org.junit.jupiter.api.AfterAll; | ||
import org.junit.jupiter.api.BeforeAll; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.TestInstance; | ||
|
||
import java.util.List; | ||
|
||
import static io.trino.plugin.jdbc.CastDataTypeTest.CastTestCase; | ||
import static java.util.Objects.requireNonNull; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.assertj.core.api.Assertions.assertThatThrownBy; | ||
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; | ||
|
||
@TestInstance(PER_CLASS) | ||
public abstract class BaseCastPushdownTest | ||
{ | ||
protected final Session session; | ||
protected final QueryRunner queryRunner; | ||
protected final QueryAssertions assertions; | ||
protected CastDataTypeTest table1; | ||
protected CastDataTypeTest table2; | ||
|
||
public BaseCastPushdownTest(Session session, QueryRunner queryRunner) | ||
{ | ||
this.session = requireNonNull(session, "session is null"); | ||
this.queryRunner = requireNonNull(queryRunner, "queryRunner is null"); | ||
this.assertions = new QueryAssertions(queryRunner); | ||
} | ||
|
||
protected abstract void setupTable(); | ||
protected abstract List<CastTestCase> supportedCastTypePushdown(); | ||
protected abstract List<CastTestCase> unsupportedCastTypePushdown(); | ||
protected abstract List<CastTestCase> failCast(); | ||
|
||
@BeforeAll | ||
public void setup() | ||
{ | ||
setupTable(); | ||
} | ||
|
||
@AfterAll | ||
public void cleanup() | ||
{ | ||
dropTable(table1); | ||
dropTable(table2); | ||
} | ||
|
||
private void dropTable(CastDataTypeTest table) | ||
{ | ||
if (table2 != null) { | ||
queryRunner.execute("DROP TABLE IF EXISTS " + table.tableName()); | ||
} | ||
} | ||
|
||
@Test | ||
public void testCastProjectionPushdown() | ||
{ | ||
for (CastDataTypeTest.CastTestCase testCase : supportedCastTypePushdown()) { | ||
assertThat(assertions.query(session, "SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), table1.tableName()))) | ||
.isFullyPushedDown(); | ||
} | ||
|
||
for (CastTestCase testCase : unsupportedCastTypePushdown()) { | ||
assertThat(assertions.query(session, "SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), table1.tableName()))) | ||
.isNotFullyPushedDown(ProjectNode.class); | ||
} | ||
} | ||
|
||
@Test | ||
public void testCastProjectionFails() | ||
{ | ||
for (CastTestCase testCase : failCast()) { | ||
assertThatThrownBy(() -> queryRunner.execute(session, "SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), table1.tableName()))) | ||
.hasMessageMatching("(.*)Cannot cast (.*) to (.*)"); | ||
} | ||
} | ||
|
||
@Test | ||
public void testJoinPushdownWithCast() | ||
{ | ||
// combination of supported types | ||
for (CastTestCase testCase : supportedCastTypePushdown()) { | ||
assertJoinFullyPushedDown(session, table1.tableName(), table2.tableName(), "JOIN", "CAST(l.%s AS %s) = r.%s".formatted(testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow())); | ||
} | ||
|
||
// combination of unsupported types | ||
for (CastTestCase testCase : unsupportedCastTypePushdown()) { | ||
assertJoinNotFullyPushedDown(session, table1.tableName(), table2.tableName(), "CAST(l.%s AS %s) = r.%s".formatted(testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow())); | ||
} | ||
} | ||
|
||
protected void assertJoinFullyPushedDown(Session session, String leftTable, String rightTable, String joinType, String joinCondition) | ||
{ | ||
assertJoin(session, leftTable, rightTable, joinType, joinCondition) | ||
.isFullyPushedDown(); | ||
} | ||
|
||
protected void assertJoinNotFullyPushedDown(Session session, String leftTable, String rightTable, String joinCondition) | ||
{ | ||
assertJoin(session, leftTable, rightTable, "JOIN", joinCondition) | ||
.joinIsNotFullyPushedDown(); | ||
} | ||
|
||
protected QueryAssertions.QueryAssert assertJoin(Session session, String leftTable, String rightTable, String joinType, String joinCondition) | ||
{ | ||
System.out.println("SELECT l.id FROM %s l %s %s r ON %s".formatted(leftTable, joinType, rightTable, joinCondition)); | ||
return assertThat(assertions.query(session, "SELECT l.id FROM %s l %s %s r ON %s".formatted(leftTable, joinType, rightTable, joinCondition))); | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/CastDataTypeTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.plugin.jdbc; | ||
|
||
import io.trino.testing.QueryRunner; | ||
import io.trino.testing.sql.TestTable; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Optional; | ||
import java.util.OptionalInt; | ||
|
||
import static com.google.common.base.Preconditions.checkArgument; | ||
import static com.google.common.base.Preconditions.checkState; | ||
import static java.util.Objects.requireNonNull; | ||
|
||
public class CastDataTypeTest | ||
{ | ||
private final List<TestCaseInput> testCaseInputs = new ArrayList<>(); | ||
|
||
private OptionalInt inputSize = OptionalInt.empty(); | ||
private TestTable testTable; | ||
|
||
private CastDataTypeTest() {} | ||
|
||
public static CastDataTypeTest create() | ||
{ | ||
return new CastDataTypeTest(); | ||
} | ||
|
||
public CastDataTypeTest addColumn(String columnName, String columnType, List<Object> inputValues) | ||
{ | ||
if (inputSize.isEmpty()) { | ||
inputSize = OptionalInt.of(inputValues.size()); | ||
} | ||
checkArgument(inputSize.getAsInt() == inputValues.size(), "Expected input size: %s, but found: %s", inputSize.getAsInt(), inputValues.size()); | ||
testCaseInputs.add(new TestCaseInput(columnName, columnType, inputValues)); | ||
return this; | ||
} | ||
|
||
public CastDataTypeTest execute(QueryRunner queryRunner, String tableNamePrefix) | ||
{ | ||
checkState(!testCaseInputs.isEmpty(), "No test case inputs"); | ||
List<String> columnsWithType = new ArrayList<>(); | ||
for (TestCaseInput input : testCaseInputs) { | ||
columnsWithType.add(input.columnName() + " " + input.columnType()); | ||
} | ||
String tableDefinition = "(" + String.join(", ", columnsWithType) + ")"; | ||
|
||
if (inputSize.isPresent()) { | ||
List<String> rowToInsert = new ArrayList<>(); | ||
for (int i = 0; i < inputSize.getAsInt(); i++) { | ||
List<String> row = new ArrayList<>(); | ||
for (TestCaseInput input : testCaseInputs) { | ||
row.add(String.valueOf(input.inputValues().get(i))); | ||
} | ||
rowToInsert.add(String.join(", ", row)); | ||
} | ||
testTable = new TestTable(queryRunner::execute, tableNamePrefix, tableDefinition, rowToInsert); | ||
} | ||
else { | ||
testTable = new TestTable(queryRunner::execute, tableNamePrefix, tableDefinition); | ||
} | ||
return this; | ||
} | ||
|
||
public String tableName() | ||
{ | ||
return testTable.getName(); | ||
} | ||
|
||
public record TestCaseInput(String columnName, String columnType, List<Object> inputValues) | ||
{ | ||
public TestCaseInput | ||
{ | ||
requireNonNull(columnName, "columnName is null"); | ||
requireNonNull(columnType, "columnType is null"); | ||
requireNonNull(inputValues, "inputValues is null"); | ||
} | ||
} | ||
|
||
public record CastTestCase(String sourceColumn, String castType, Optional<String> targetColumn) | ||
{ | ||
public CastTestCase | ||
{ | ||
requireNonNull(sourceColumn, "sourceColumn is null"); | ||
requireNonNull(castType, "castType is null"); | ||
requireNonNull(targetColumn, "targetColumn is null"); | ||
} | ||
|
||
public CastTestCase(String sourceColumn, String castType) | ||
{ | ||
this(sourceColumn, castType, Optional.empty()); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.