From 750000ca5797351866dcc0e846f9e52becf12b90 Mon Sep 17 00:00:00 2001 From: Szymon Homa Date: Tue, 3 Aug 2021 14:08:05 +0200 Subject: [PATCH] Fix Cassandra OrderedRangeSet pushdown Cassandra supports only =, <, >, <=, >=, IN (...) and AND in pushed-down predicates, so ranges are translated as follows. When we have a single single-valued range, we use =. When we have a single range, we use low bound < x AND x < high bound (or <= when appropriate). When we have multiple single-valued ranges, we use IN (...). In all other cases, including when IN is not supported by Cassandra, we push down the min/max bounds (domain.getValues().getRanges().getSpan()) using low bound < x AND x < high bound (or <= when appropriate). --- ...assandraClusteringPredicatesExtractor.java | 145 +++++++++++------- .../cassandra/TestCassandraConnectorTest.java | 13 ++ 2 files changed, 100 insertions(+), 58 deletions(-) diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java index 109acfe626781..c9b2e8d8326a3 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java @@ -23,12 +23,13 @@ import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; -import java.util.ArrayList; import java.util.List; import java.util.Set; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; public class CassandraClusteringPredicatesExtractor { @@ -55,7 +56,7 @@ private static ClusteringPushDownResult getClusteringKeysSet(List fullyPushedColumnPredicates = ImmutableSet.builder(); ImmutableList.Builder clusteringColumnSql = ImmutableList.builder(); - int 
currentClusteringColumn = 0; + int allProcessedClusteringColumns = 0; for (CassandraColumnHandle columnHandle : clusteringColumns) { Domain domain = predicates.getDomains().get().get(columnHandle); if (domain == null) { @@ -64,64 +65,46 @@ private static ClusteringPushDownResult getClusteringKeysSet(List { - List singleValues = new ArrayList<>(); - List rangeConjuncts = new ArrayList<>(); - String predicate = null; - - for (Range range : ranges.getOrderedRanges()) { - if (range.isAll()) { - return null; - } - if (range.isSingleValue()) { - singleValues.add(columnHandle.getCassandraType().toCqlLiteral(range.getSingleValue())); - } - else { - if (!range.isLowUnbounded()) { - String lowBound = columnHandle.getCassandraType().toCqlLiteral(range.getLowBoundedValue()); - rangeConjuncts.add(format( - "%s %s %s", - CassandraCqlUtils.validColumnName(columnHandle.getName()), - range.isLowInclusive() ? ">=" : ">", - lowBound)); - } - if (!range.isHighUnbounded()) { - String highBound = columnHandle.getCassandraType().toCqlLiteral(range.getHighBoundedValue()); - rangeConjuncts.add(format( - "%s %s %s", - CassandraCqlUtils.validColumnName(columnHandle.getName()), - range.isHighInclusive() ? 
"<=" : "<", - highBound)); - } - } - } - - if (!singleValues.isEmpty() && !rangeConjuncts.isEmpty()) { - return null; + if (ranges.getRangeCount() == 1) { + fullyPushedColumnPredicates.add(columnHandle); + return translateRangeIntoCql(columnHandle, getOnlyElement(ranges.getOrderedRanges())); } - if (!singleValues.isEmpty()) { - if (singleValues.size() == 1) { - predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " = " + singleValues.get(0); - } - else { - predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" - + Joiner.on(",").join(singleValues) + ")"; + if (ranges.getOrderedRanges().stream().allMatch(Range::isSingleValue)) { + if (isInExpressionNotAllowed(clusteringColumns, cassandraVersion, currentlyProcessedClusteringColumn)) { + return translateRangeIntoCql(columnHandle, ranges.getSpan()); } + + String inValues = ranges.getOrderedRanges().stream() + .map(range -> toCqlLiteral(columnHandle, range.getSingleValue())) + .collect(joining(",")); + fullyPushedColumnPredicates.add(columnHandle); + return CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + inValues + ")"; } - else if (!rangeConjuncts.isEmpty()) { - predicate = Joiner.on(" AND ").join(rangeConjuncts); - } - return predicate; + return translateRangeIntoCql(columnHandle, ranges.getSpan()); }, discreteValues -> { if (discreteValues.isInclusive()) { - ImmutableList.Builder discreteValuesList = ImmutableList.builder(); - for (Object discreteValue : discreteValues.getValues()) { - discreteValuesList.add(columnHandle.getCassandraType().toCqlLiteral(discreteValue)); + if (discreteValues.getValuesCount() == 0) { + return null; + } + if (discreteValues.getValuesCount() == 1) { + fullyPushedColumnPredicates.add(columnHandle); + return format("%s = %s", + CassandraCqlUtils.validColumnName(columnHandle.getName()), + toCqlLiteral(columnHandle, getOnlyElement(discreteValues.getValues()))); } - String predicate = 
CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" - + Joiner.on(",").join(discreteValuesList.build()) + ")"; - return predicate; + if (isInExpressionNotAllowed(clusteringColumns, cassandraVersion, currentlyProcessedClusteringColumn)) { + return null; + } + + String inValues = discreteValues.getValues().stream() + .map(columnHandle.getCassandraType()::toCqlLiteral) + .collect(joining(",")); + fullyPushedColumnPredicates.add(columnHandle); + return CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + inValues + " )"; } return null; }, allOrNone -> null); @@ -129,23 +112,69 @@ else if (!rangeConjuncts.isEmpty()) { if (predicateString == null) { break; } - // IN restriction only on last clustering column for Cassandra version = 2.1 - if (predicateString.contains(" IN (") && cassandraVersion.compareTo(VersionNumber.parse("2.2.0")) < 0 && currentClusteringColumn != (clusteringColumns.size() - 1)) { - break; - } clusteringColumnSql.add(predicateString); - fullyPushedColumnPredicates.add(columnHandle); // Check for last clustering column should only be restricted by range condition if (predicateString.contains(">") || predicateString.contains("<")) { break; } - currentClusteringColumn++; + allProcessedClusteringColumns++; } List clusteringColumnPredicates = clusteringColumnSql.build(); return new ClusteringPushDownResult(fullyPushedColumnPredicates.build(), Joiner.on(" AND ").join(clusteringColumnPredicates)); } + /** + * IN restriction is allowed only on the last clustering column for Cassandra versions < 2.2.0 + */ + private static boolean isInExpressionNotAllowed(List clusteringColumns, VersionNumber cassandraVersion, int currentlyProcessedClusteringColumn) + { + return cassandraVersion.compareTo(VersionNumber.parse("2.2.0")) < 0 && currentlyProcessedClusteringColumn != (clusteringColumns.size() - 1); + } + + private static String toCqlLiteral(CassandraColumnHandle columnHandle, Object value) + { + return 
columnHandle.getCassandraType().toCqlLiteral(value); + } + + private static String translateRangeIntoCql(CassandraColumnHandle columnHandle, Range range) + { + if (range.isAll()) { + return null; + } + if (range.isSingleValue()) { + return format("%s = %s", + CassandraCqlUtils.validColumnName(columnHandle.getName()), + toCqlLiteral(columnHandle, range.getSingleValue())); + } + + String lowerBoundPredicate = null; + String upperBoundPredicate = null; + if (!range.isLowUnbounded()) { + String lowBound = toCqlLiteral(columnHandle, range.getLowBoundedValue()); + lowerBoundPredicate = format( + "%s %s %s", + CassandraCqlUtils.validColumnName(columnHandle.getName()), + range.isLowInclusive() ? ">=" : ">", + lowBound); + } + if (!range.isHighUnbounded()) { + String highBound = toCqlLiteral(columnHandle, range.getHighBoundedValue()); + upperBoundPredicate = format( + "%s %s %s", + CassandraCqlUtils.validColumnName(columnHandle.getName()), + range.isHighInclusive() ? "<=" : "<", + highBound); + } + if (lowerBoundPredicate != null && upperBoundPredicate != null) { + return format("%s AND %s ", lowerBoundPredicate, upperBoundPredicate); + } + if (lowerBoundPredicate != null) { + return lowerBoundPredicate; + } + return upperBoundPredicate; + } + private static class ClusteringPushDownResult { private final Set fullyPushedColumnPredicates; diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java index 470a429e63666..070459f42799e 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java @@ -409,6 +409,19 @@ public void testClusteringKeyOnlyPushdown() assertEquals(execute(sql).getRowCount(), 1); } + @Test + public void testNotEqualPredicateOnClusteringColumn() + { + String sql = 
"SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one != 'clust_one'"; + assertEquals(execute(sql).getRowCount(), 0); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two != 2"; + assertEquals(execute(sql).getRowCount(), 3); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two >= 2 AND clust_two != 3"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two > 2 AND clust_two != 3"; + assertEquals(execute(sql).getRowCount(), 1); + } + @Test public void testClusteringKeyPushdownInequality() {