Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cast projection pushdown in oracle #22728

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
krvikash marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.Type;

import java.util.Optional;
import java.util.function.BiFunction;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static java.util.Objects.requireNonNull;

public abstract class AbstractRewriteCast
implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression>
{
private static final Capture<Variable> VALUE = newCapture();

private final BiFunction<ConnectorSession, Type, String> jdbcTypeProvider;

protected abstract Optional<JdbcTypeHandle> toJdbcTypeHandle(JdbcTypeHandle sourceType, Type targetType);

public AbstractRewriteCast(BiFunction<ConnectorSession, Type, String> jdbcTypeProvider)
{
this.jdbcTypeProvider = requireNonNull(jdbcTypeProvider, "jdbcTypeProvider is null");
}

protected String buildCast(@SuppressWarnings("unused") Type sourceType, @SuppressWarnings("unused") Type targetType, String expression, String castType)
{
return "CAST(%s AS %s)".formatted(expression, castType);
}

@Override
public Pattern<Call> getPattern()
{
return call()
.with(functionName().equalTo(CAST_FUNCTION_NAME))
.with(argumentCount().equalTo(1))
.with(argument(0).matching(variable().capturedAs(VALUE)));
}

@Override
public Optional<JdbcExpression> rewrite(ConnectorTableHandle handle, ConnectorExpression castExpression, Captures captures, RewriteContext<ParameterizedExpression> context)
{
Variable variable = captures.get(VALUE);
JdbcTypeHandle sourceTypeJdbcHandle = ((JdbcColumnHandle) context.getAssignment(variable.getName())).getJdbcTypeHandle();
Type targetType = castExpression.getType();
Optional<JdbcTypeHandle> targetJdbcTypeHandle = toJdbcTypeHandle(sourceTypeJdbcHandle, targetType);

if (targetJdbcTypeHandle.isEmpty()) {
return Optional.empty();
}

Optional<ParameterizedExpression> value = context.rewriteExpression(variable);
if (value.isEmpty()) {
return Optional.empty();
}

Type sourceType = variable.getType();
String castType = jdbcTypeProvider.apply(context.getSession(), targetType);

return Optional.of(new JdbcExpression(
buildCast(sourceType, targetType, value.get().expression(), castType),
value.get().parameters(),
targetJdbcTypeHandle.get()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import io.trino.sql.planner.plan.ProjectNode;
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.sql.SqlExecutor;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public abstract class BaseJdbcCastPushdownTest
extends AbstractTestQueryFramework
{
protected abstract String leftTable();

protected abstract String rightTable();

protected abstract SqlExecutor onRemoteDatabase();

protected abstract List<CastTestCase> supportedCastTypePushdown();

protected abstract List<CastTestCase> unsupportedCastTypePushdown();

protected abstract List<CastTestCase> failCast();

@Test
public void testProjectionPushdownWithCast()
{
for (CastTestCase testCase : supportedCastTypePushdown()) {
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable())))
.isFullyPushedDown();
}

for (CastTestCase testCase : unsupportedCastTypePushdown()) {
assertThat(query("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable())))
.isNotFullyPushedDown(ProjectNode.class);
}
}

@Test
public void testJoinPushdownWithCast()
{
for (CastTestCase testCase : supportedCastTypePushdown()) {
assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow())))
.isFullyPushedDown();
}

for (CastTestCase testCase : unsupportedCastTypePushdown()) {
assertThat(query("SELECT l.id FROM %s l JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn().orElseThrow())))
.joinIsNotFullyPushedDown();
}
}

@Test
public void testCastFails()
{
for (CastTestCase testCase : failCast()) {
assertThatThrownBy(() -> getQueryRunner().execute("SELECT CAST(%s AS %s) FROM %s".formatted(testCase.sourceColumn(), testCase.castType(), leftTable())))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this test fails, it doesn't say which condition has failed exactly. Do you think we could improve on that?
Took a quick look - as() or withFailMessage() does not seem to work, maybe custom satifies() would.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @SemionPar, I will take this as a follow up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@krvikash Sure thing!

One more thought that crossed my mind: if cast fails at runtime with different message when pushed down, say:

      Trino: Cannot cast DECIMAL(19, 2) '99999999999999999.99' to DECIMAL(10, 2)
      Teradata: [Error 2616] [SQLState 22003] Numeric overflow occurred during computation.

It would be good to be able to use testCastFails to test such cases, don't you think? We might want to add message customization to CastTestCase (or introduce FailCastTestCase) - consider this another improvement idea.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SemionPar Here is the refactoring PR #23737

.hasMessageMatching("(.*)Cannot cast (.*) to (.*)");
Comment on lines +75 to +76
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it thrown by Trino or by Oracle or by both ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By Trino.

}
}

public record CastTestCase(String sourceColumn, String castType, Optional<String> targetColumn)
{
public CastTestCase(String sourceColumn, String castType)
{
this(sourceColumn, castType, Optional.empty());
}

public CastTestCase
{
requireNonNull(sourceColumn, "sourceColumn is null");
requireNonNull(castType, "castType is null");
requireNonNull(targetColumn, "targetColumn is null");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import io.trino.testing.sql.SqlExecutor;
import io.trino.testing.sql.TemporaryRelation;
import io.trino.testing.sql.TestTable;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.testing.Closeables.closeAll;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;

public final class CastDataTypeTestTable
implements TemporaryRelation
{
private final List<TestCaseInput> testCaseRows = new ArrayList<>();
private final int rowCount;

private TestTable testTable;

private CastDataTypeTestTable(int rowCount)
{
this.rowCount = rowCount;
}

public static CastDataTypeTestTable create(int rowCount)
{
return new CastDataTypeTestTable(rowCount);
}

public CastDataTypeTestTable addColumn(String columnName, String columnType, List<Object> inputValues)
{
checkArgument(rowCount == inputValues.size(), "Expected input size: %s, but found: %s", rowCount, inputValues.size());
testCaseRows.add(new TestCaseInput(columnName, columnType, inputValues));
return this;
}

public CastDataTypeTestTable execute(SqlExecutor sqlExecutor, String tableNamePrefix)
{
checkState(!testCaseRows.isEmpty(), "No test case rows");
List<String> columnsWithType = new ArrayList<>();
for (TestCaseInput input : testCaseRows) {
columnsWithType.add("%s %s".formatted(input.columnName(), input.columnType()));
}
String tableDefinition = columnsWithType.stream()
.collect(joining(", ", "(", ")"));

List<String> rowsToInsert = IntStream.range(0, rowCount)
.mapToObj(rowId -> testCaseRows.stream()
.map(TestCaseInput::inputValues)
.map(rows -> rows.get(rowId))
.map(String::valueOf)
.collect(joining(",")))
.collect(toImmutableList());
testTable = new TestTable(sqlExecutor, tableNamePrefix, tableDefinition, rowsToInsert);
return this;
}

@Override
public String getName()
{
return testTable.getName();
}

@Override
public void close()
{
try {
closeAll(testTable);
}
catch (Exception e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make method throws

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

base class close method does not throws. So can't make this method throws. May be as a followup?

throw new RuntimeException(e);
}
}

public record TestCaseInput(String columnName, String columnType, List<Object> inputValues)
{
public TestCaseInput
{
requireNonNull(columnName, "columnName is null");
requireNonNull(columnType, "columnType is null");
inputValues = new ArrayList<>(requireNonNull(inputValues, "inputValues is null"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.base.projection.ProjectFunctionRewriter;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.BooleanWriteFunction;
Expand Down Expand Up @@ -165,10 +167,10 @@ public class OracleClient
private static final int MAX_BYTES_PER_CHAR = 4;

private static final int ORACLE_VARCHAR2_MAX_BYTES = 4000;
private static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR;
public static final int ORACLE_VARCHAR2_MAX_CHARS = ORACLE_VARCHAR2_MAX_BYTES / MAX_BYTES_PER_CHAR;

private static final int ORACLE_CHAR_MAX_BYTES = 2000;
private static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR;
public static final int ORACLE_CHAR_MAX_CHARS = ORACLE_CHAR_MAX_BYTES / MAX_BYTES_PER_CHAR;

private static final int PRECISION_OF_UNSPECIFIED_NUMBER = 127;

Expand Down Expand Up @@ -215,6 +217,7 @@ public class OracleClient

private final boolean synonymsEnabled;
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
private final ProjectFunctionRewriter<JdbcExpression, ParameterizedExpression> projectFunctionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;
private final Optional<Integer> fetchSize;

Expand Down Expand Up @@ -243,6 +246,12 @@ public OracleClient(
.add(new RewriteStringComparison())
.build();

this.projectFunctionRewriter = new ProjectFunctionRewriter<>(
this.connectorExpressionRewriter,
ImmutableSet.<ProjectFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
krvikash marked this conversation as resolved.
Show resolved Hide resolved
.add(new RewriteCast((session, type) -> toWriteMapping(session, type).getDataType()))
.build());

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(TRINO_BIGINT_TYPE, Optional.of("NUMBER"), Optional.of(0), Optional.of(0), Optional.empty(), Optional.empty());
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
connectorExpressionRewriter,
Expand Down Expand Up @@ -559,6 +568,12 @@ public Optional<ParameterizedExpression> convertPredicate(ConnectorSession sessi
return connectorExpressionRewriter.rewrite(session, expression, assignments);
}

@Override
public Optional<JdbcExpression> convertProjection(ConnectorSession session, JdbcTableHandle handle, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return projectFunctionRewriter.rewrite(session, handle, expression, assignments);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(OracleTypes.NUMBER, Optional.of("NUMBER"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
Expand Down
Loading