Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support range predicate pushdown for string columns with collation in PostgreSQL connector #9746

Merged
merged 1 commit into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
import static io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping.AS_JSON;
import static io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping.DISABLED;
import static io.trino.plugin.postgresql.PostgreSqlSessionProperties.getArrayMapping;
import static io.trino.plugin.postgresql.PostgreSqlSessionProperties.isEnableStringPushdownWithCollate;
import static io.trino.plugin.postgresql.TypeUtils.arrayDepth;
import static io.trino.plugin.postgresql.TypeUtils.getArrayElementPgTypeName;
import static io.trino.plugin.postgresql.TypeUtils.getJdbcObjectArray;
Expand Down Expand Up @@ -234,7 +235,7 @@ public class PostgreSqlClient
private final List<String> tableTypes;
private final AggregateFunctionRewriter<JdbcExpression> aggregateFunctionRewriter;

private static final PredicatePushdownController POSTGRESQL_CHARACTER_PUSHDOWN = (session, domain) -> {
private static final PredicatePushdownController POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE = (session, domain) -> {
checkArgument(
domain.getType() instanceof VarcharType || domain.getType() instanceof CharType,
"This PredicatePushdownController can be used only for chars and varchars");
Expand Down Expand Up @@ -506,13 +507,22 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
}

case Types.CHAR:
if (isEnableStringPushdownWithCollate(session)) {
return Optional.of(charColumnMappingWithCollate(typeHandle.getRequiredColumnSize()));
}
return Optional.of(charColumnMapping(typeHandle.getRequiredColumnSize()));

case Types.VARCHAR:
if (!jdbcTypeName.equals("varchar")) {
// This can be e.g. an ENUM
if (isCollatable(jdbcTypeName) && isEnableStringPushdownWithCollate(session)) {
return Optional.of(typedVarcharColumnMappingWithCollate(jdbcTypeName));
}
return Optional.of(typedVarcharColumnMapping(jdbcTypeName));
}
if (isCollatable(jdbcTypeName) && isEnableStringPushdownWithCollate(session)) {
return Optional.of(varcharColumnMappingWithCollate(typeHandle.getRequiredColumnSize()));
}
return Optional.of(varcharColumnMapping(typeHandle.getRequiredColumnSize()));

case Types.BINARY:
Expand Down Expand Up @@ -749,14 +759,19 @@ private boolean isCollatable(JdbcColumnHandle column)
if (column.getColumnType() instanceof CharType || column.getColumnType() instanceof VarcharType) {
String jdbcTypeName = column.getJdbcTypeHandle().getJdbcTypeName()
.orElseThrow(() -> new TrinoException(JDBC_ERROR, "Type name is missing: " + column.getJdbcTypeHandle()));
// Only char (internally named bpchar)/varchar/text are the built-in collatable types
// NOTE: GitHub review-UI residue ("hashhar marked this conversation as resolved") — not part of the code.
return "bpchar".equals(jdbcTypeName) || "varchar".equals(jdbcTypeName) || "text".equals(jdbcTypeName);
return isCollatable(jdbcTypeName);
}

// non-textual types don't have the concept of collation
return false;
}

private static boolean isCollatable(String jdbcTypeName)
{
    // Only char (internally named bpchar), varchar and text are PostgreSQL's built-in
    // collatable string types. Other textual type names (e.g. an ENUM, see the
    // Types.VARCHAR handling in toColumnMapping) are not collatable by name alone.
    // Made static: the method reads no instance state and is used by the static
    // column-mapping factory helpers' call sites.
    return "bpchar".equals(jdbcTypeName) || "varchar".equals(jdbcTypeName) || "text".equals(jdbcTypeName);
}

@Override
public boolean isTopNGuaranteed(ConnectorSession session)
{
Expand Down Expand Up @@ -833,7 +848,20 @@ private static ColumnMapping charColumnMapping(int charLength)
charType,
charReadFunction(charType),
charWriteFunction(),
POSTGRESQL_CHARACTER_PUSHDOWN);
POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE);
}

private static ColumnMapping charColumnMappingWithCollate(int charLength)
{
    // char columns wider than Trino's CharType limit are mapped as varchar instead,
    // mirroring charColumnMapping
    if (charLength > CharType.MAX_LENGTH) {
        return varcharColumnMappingWithCollate(charLength);
    }
    CharType trinoCharType = createCharType(charLength);
    // Binding values with an explicit collation allows full predicate pushdown
    return ColumnMapping.sliceMapping(
            trinoCharType,
            charReadFunction(trinoCharType),
            stringWriteFunctionWithCollate(),
            FULL_PUSHDOWN);
}

private static ColumnMapping varcharColumnMapping(int varcharLength)
Expand All @@ -845,7 +873,38 @@ private static ColumnMapping varcharColumnMapping(int varcharLength)
varcharType,
varcharReadFunction(varcharType),
varcharWriteFunction(),
POSTGRESQL_CHARACTER_PUSHDOWN);
POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE);
}

private static ColumnMapping varcharColumnMappingWithCollate(int varcharLength)
{
    // Lengths beyond Trino's varchar limit map to unbounded varchar, mirroring varcharColumnMapping
    VarcharType trinoVarcharType;
    if (varcharLength <= VarcharType.MAX_LENGTH) {
        trinoVarcharType = createVarcharType(varcharLength);
    }
    else {
        trinoVarcharType = createUnboundedVarcharType();
    }
    // Binding values with an explicit collation allows full predicate pushdown
    return ColumnMapping.sliceMapping(
            trinoVarcharType,
            varcharReadFunction(trinoVarcharType),
            stringWriteFunctionWithCollate(),
            FULL_PUSHDOWN);
}

// Write function that appends an explicit COLLATE "C" to the bound parameter.
// PostgreSQL's "C" collation compares strings byte-wise, which is intended to match
// Trino's ordering for varchar/char so that range predicates can be pushed down
// safely — NOTE(review): relies on the session/column encoding being UTF-8; confirm.
private static SliceWriteFunction stringWriteFunctionWithCollate()
{
    return new SliceWriteFunction()
    {
        @Override
        public String getBindExpression()
        {
            // Placeholder is rendered as e.g. «? COLLATE "C"» in the generated SQL
            return "? COLLATE \"C\"";
        }

        @Override
        public void set(PreparedStatement statement, int index, Slice value)
                throws SQLException
        {
            // Slices carry UTF-8 bytes; convert to a Java String for JDBC binding
            statement.setString(index, value.toStringUtf8());
        }
    };
}

private static ColumnMapping timeColumnMapping(int precision)
Expand Down Expand Up @@ -1161,12 +1220,45 @@ private static ColumnMapping typedVarcharColumnMapping(String jdbcTypeName)
return ColumnMapping.sliceMapping(
VARCHAR,
(resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)),
typedVarcharWriteFunction(jdbcTypeName));
typedVarcharWriteFunction(jdbcTypeName),
POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE);
}

// Mapping for named textual types (e.g. domains over text); like typedVarcharColumnMapping
// but binds values with an explicit collation, enabling full predicate pushdown.
private static ColumnMapping typedVarcharColumnMappingWithCollate(String jdbcTypeName)
{
    return ColumnMapping.sliceMapping(
            VARCHAR,
            (resultSet, index) -> utf8Slice(resultSet.getString(index)),
            typedVarcharWriteFunctionWithCollate(jdbcTypeName),
            FULL_PUSHDOWN);
}

// Write function for named textual types: casts the bound parameter to the
// PostgreSQL type so comparisons happen in the column's own type.
private static SliceWriteFunction typedVarcharWriteFunction(String jdbcTypeName)
{
    requireNonNull(jdbcTypeName, "jdbcTypeName is null");
    String castExpression = format("CAST(? AS %s)", jdbcTypeName);

    return new SliceWriteFunction()
    {
        @Override
        public String getBindExpression()
        {
            return castExpression;
        }

        @Override
        public void set(PreparedStatement statement, int index, Slice value)
                throws SQLException
        {
            // Slices carry UTF-8 bytes; convert to a Java String for JDBC binding
            statement.setString(index, value.toStringUtf8());
        }
    };
}

private static SliceWriteFunction typedVarcharWriteFunctionWithCollate(String jdbcTypeName)
{
String collation = "COLLATE \"C\"";
String bindExpression = format("CAST(? AS %s) %s", requireNonNull(jdbcTypeName, "jdbcTypeName is null"), collation);

return new SliceWriteFunction()
{
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class PostgreSqlConfig
{
private ArrayMapping arrayMapping = ArrayMapping.DISABLED;
private boolean includeSystemTables;
private boolean enableStringPushdownWithCollate;

public enum ArrayMapping
{
Expand Down Expand Up @@ -55,4 +56,16 @@ public PostgreSqlConfig setIncludeSystemTables(boolean includeSystemTables)
this.includeSystemTables = includeSystemTables;
return this;
}

// Experimental toggle (default false) controlling whether string predicates
// may be pushed down to PostgreSQL; see the corresponding setter for the config name.
public boolean isEnableStringPushdownWithCollate()
{
    return enableStringPushdownWithCollate;
}

// Bound to the experimental catalog property; the "experimental." prefix signals
// the option may change or be removed in a future release.
@Config("postgresql.experimental.enable-string-pushdown-with-collate")
public PostgreSqlConfig setEnableStringPushdownWithCollate(boolean enableStringPushdownWithCollate)
{
    this.enableStringPushdownWithCollate = enableStringPushdownWithCollate;
    return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@

import java.util.List;

import static io.trino.spi.session.PropertyMetadata.booleanProperty;
import static io.trino.spi.session.PropertyMetadata.enumProperty;

public final class PostgreSqlSessionProperties
implements SessionPropertiesProvider
{
public static final String ARRAY_MAPPING = "array_mapping";
public static final String ENABLE_STRING_PUSHDOWN_WITH_COLLATE = "enable_string_pushdown_with_collate";

private final List<PropertyMetadata<?>> sessionProperties;

Expand All @@ -41,6 +43,11 @@ public PostgreSqlSessionProperties(PostgreSqlConfig postgreSqlConfig)
"Handling of PostgreSql arrays",
ArrayMapping.class,
postgreSqlConfig.getArrayMapping(),
false),
booleanProperty(
ENABLE_STRING_PUSHDOWN_WITH_COLLATE,
"Enable string pushdown with collate (experimental)",
postgreSqlConfig.isEnableStringPushdownWithCollate(),
false));
}

Expand All @@ -54,4 +61,9 @@ public static ArrayMapping getArrayMapping(ConnectorSession session)
{
return session.getProperty(ARRAY_MAPPING, ArrayMapping.class);
}

// Reads the per-session value of enable_string_pushdown_with_collate
// (defaulted from PostgreSqlConfig when the property is not set explicitly).
public static boolean isEnableStringPushdownWithCollate(ConnectorSession session)
{
    return session.getProperty(ENABLE_STRING_PUSHDOWN_WITH_COLLATE, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public static DistributedQueryRunner createPostgreSqlQueryRunner(
connectorProperties.putIfAbsent("connection-password", server.getPassword());
connectorProperties.putIfAbsent("allow-drop-table", "true");
connectorProperties.putIfAbsent("postgresql.include-system-tables", "true");
//connectorProperties.putIfAbsent("postgresql.experimental.enable-string-pushdown-with-collate", "true");

queryRunner.installPlugin(new PostgreSqlPlugin());
queryRunner.createCatalog("postgresql", "postgresql", connectorProperties);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public void testDefaults()
{
assertRecordedDefaults(recordDefaults(PostgreSqlConfig.class)
.setArrayMapping(PostgreSqlConfig.ArrayMapping.DISABLED)
.setIncludeSystemTables(false));
.setIncludeSystemTables(false)
.setEnableStringPushdownWithCollate(false));
}

@Test
Expand All @@ -38,11 +39,13 @@ public void testExplicitPropertyMappings()
Map<String, String> properties = new ImmutableMap.Builder<String, String>()
.put("postgresql.array-mapping", "AS_ARRAY")
.put("postgresql.include-system-tables", "true")
.put("postgresql.experimental.enable-string-pushdown-with-collate", "true")
.build();

PostgreSqlConfig expected = new PostgreSqlConfig()
.setArrayMapping(PostgreSqlConfig.ArrayMapping.AS_ARRAY)
.setIncludeSystemTables(true);
.setIncludeSystemTables(true)
.setEnableStringPushdownWithCollate(true);

assertFullMapping(properties, expected);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
*/
package io.trino.plugin.postgresql;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.units.Duration;
import io.trino.Session;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.RemoteDatabaseEvent;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
Expand All @@ -42,7 +46,10 @@
import java.util.Map;
import java.util.UUID;

import static com.google.common.collect.MoreCollectors.onlyElement;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
Expand Down Expand Up @@ -82,6 +89,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
{
switch (connectorBehavior) {
case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY:
case SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_INEQUALITY:
// NOTE: GitHub review-UI residue ("hashhar marked this conversation as resolved") — not part of the code.
return false;
Copy link
Member Author

@takezoe takezoe Oct 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional updates are necessary to support aggregation pushdown for varchar columns. Since I want to focus on supporting range predicate pushdown in this pull request, I have just introduced a new behavior flag that indicates whether aggregation pushdown is supported for varchar columns.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's create a GitHub issue and add a TODO so that we can remove the additional behaviour once it's no longer needed (and other people know that this is something they can work on).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that makes sense (as long as this fix can be merged) and I would work on it once this pull request is completed. By the way, supporting type sensitive aggregation pushdown in JDBC plugins doesn't seem easy. Fundamental interface changes may be required.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I ran into some similar issues in #7320.

There were some ideas floated in #7320 (comment) (which we didn't end up doing since it looked like a one-off need at that time).

If you already have some direction in your mind it might be helpful to discuss it on #dev on Slack too (if you think the changes will be large and touch the SPI).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, will take a look at #7320 first. Thanks.


case SUPPORTS_TOPN_PUSHDOWN:
Expand Down Expand Up @@ -368,6 +376,59 @@ public void testPredicatePushdown()
anyTree(node(TableScanNode.class))));
}

// Verifies that, with enable_string_pushdown_with_collate enabled, range and IN
// predicates on varchar columns are pushed down to PostgreSQL (rejected by default
// because PostgreSQL collation order may differ from Trino's byte-wise ordering).
@Test
public void testStringPushdownWithCollate()
{
    // Session with the experimental collation-aware pushdown enabled
    Session session = Session.builder(getSession())
            .setCatalogSessionProperty("postgresql", "enable_string_pushdown_with_collate", "true")
            .build();

    // varchar range
    assertThat(query(session, "SELECT regionkey, nationkey, name FROM nation WHERE name BETWEEN 'POLAND' AND 'RPA'"))
            .matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))")
            .isFullyPushedDown();

    // varchar IN with small compaction threshold
    assertThat(query(
            Session.builder(session)
                    .setCatalogSessionProperty("postgresql", "domain_compaction_threshold", "1")
                    .build(),
            "SELECT regionkey, nationkey, name FROM nation WHERE name IN ('POLAND', 'ROMANIA', 'VIETNAM')"))
            .matches("VALUES " +
                    "(BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25))), " +
                    "(BIGINT '2', BIGINT '21', CAST('VIETNAM' AS varchar(25)))")
            // Verify that a FilterNode is retained and only a compacted domain is pushed down to connector as a range predicate
            .isNotFullyPushedDown(node(FilterNode.class, tableScan(
                    tableHandle -> {
                        TupleDomain<ColumnHandle> constraint = ((JdbcTableHandle) tableHandle).getConstraint();
                        // Locate the handle for the "name" column among the pushed-down domains
                        ColumnHandle nameColumn = constraint.getDomains().orElseThrow()
                                .keySet().stream()
                                .map(JdbcColumnHandle.class::cast)
                                .filter(column -> column.getColumnName().equals("name"))
                                .collect(onlyElement());
                        // With threshold 1, the three-point IN list compacts to one covering range
                        return constraint.getDomains().get().get(nameColumn).getValues().getRanges().getOrderedRanges()
                                .equals(ImmutableList.of(
                                        Range.range(
                                                createVarcharType(25),
                                                utf8Slice("POLAND"), true,
                                                utf8Slice("VIETNAM"), true)));
                    },
                    TupleDomain.all(),
                    ImmutableMap.of())));

    // varchar predicate over join
    Session joinPushdownEnabled = joinPushdownEnabled(session);
    assertThat(query(joinPushdownEnabled, "SELECT c.name, n.name FROM customer c JOIN nation n ON c.custkey = n.nationkey WHERE address < 'TcGe5gaZNgVePxU5kRrvXBfkasDTea'"))
            .isFullyPushedDown();
    hashhar marked this conversation as resolved.
    Show resolved Hide resolved

    // join on varchar columns is not pushed down
    assertThat(query(joinPushdownEnabled, "SELECT c.name, n.name FROM customer c JOIN nation n ON c.address = n.name"))
            .isNotFullyPushedDown(
                    node(JoinNode.class,
                            anyTree(node(TableScanNode.class)),
                            anyTree(node(TableScanNode.class))));
}

@Test
public void testDecimalPredicatePushdown()
throws Exception
Expand Down