diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java index b11a39057eeed..c9b2e8d8326a3 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClusteringPredicatesExtractor.java @@ -16,19 +16,20 @@ import com.datastax.driver.core.VersionNumber; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.trino.plugin.cassandra.util.CassandraCqlUtils; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; -import java.util.ArrayList; import java.util.List; -import java.util.Map; +import java.util.Set; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; public class CassandraClusteringPredicatesExtractor { @@ -48,15 +49,14 @@ public String getClusteringKeyPredicates() public TupleDomain getUnenforcedConstraints() { - Map pushedDown = clusteringPushDownResult.getDomains(); - return predicates.filter(((columnHandle, domain) -> !pushedDown.containsKey(columnHandle))); + return predicates.filter(((columnHandle, domain) -> !clusteringPushDownResult.hasBeenFullyPushed(columnHandle))); } private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates, VersionNumber cassandraVersion) { - ImmutableMap.Builder domainsBuilder = ImmutableMap.builder(); + ImmutableSet.Builder fullyPushedColumnPredicates = ImmutableSet.builder(); ImmutableList.Builder clusteringColumnSql = ImmutableList.builder(); - int currentClusteringColumn = 0; + int allProcessedClusteringColumns = 0; for (CassandraColumnHandle columnHandle : clusteringColumns) { Domain domain = predicates.getDomains().get().get(columnHandle); if (domain == null) { @@ -65,64 +65,46 @@ private static ClusteringPushDownResult getClusteringKeysSet(List { - List singleValues = new ArrayList<>(); - List rangeConjuncts = new ArrayList<>(); - String predicate = null; - - for (Range range : ranges.getOrderedRanges()) { - if (range.isAll()) { - return null; - } - if (range.isSingleValue()) { - singleValues.add(columnHandle.getCassandraType().toCqlLiteral(range.getSingleValue())); - } - else { - if (!range.isLowUnbounded()) { - String lowBound = columnHandle.getCassandraType().toCqlLiteral(range.getLowBoundedValue()); - rangeConjuncts.add(format( - "%s %s %s", - CassandraCqlUtils.validColumnName(columnHandle.getName()), - range.isLowInclusive() ? ">=" : ">", - lowBound)); - } - if (!range.isHighUnbounded()) { - String highBound = columnHandle.getCassandraType().toCqlLiteral(range.getHighBoundedValue()); - rangeConjuncts.add(format( - "%s %s %s", - CassandraCqlUtils.validColumnName(columnHandle.getName()), - range.isHighInclusive() ? "<=" : "<", - highBound)); - } - } - } - - if (!singleValues.isEmpty() && !rangeConjuncts.isEmpty()) { - return null; + if (ranges.getRangeCount() == 1) { + fullyPushedColumnPredicates.add(columnHandle); + return translateRangeIntoCql(columnHandle, getOnlyElement(ranges.getOrderedRanges())); } - if (!singleValues.isEmpty()) { - if (singleValues.size() == 1) { - predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " = " + singleValues.get(0); - } - else { - predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" - + Joiner.on(",").join(singleValues) + ")"; + if (ranges.getOrderedRanges().stream().allMatch(Range::isSingleValue)) { + if (isInExpressionNotAllowed(clusteringColumns, cassandraVersion, currentlyProcessedClusteringColumn)) { + return translateRangeIntoCql(columnHandle, ranges.getSpan()); } + + String inValues = ranges.getOrderedRanges().stream() + .map(range -> toCqlLiteral(columnHandle, range.getSingleValue())) + .collect(joining(",")); + fullyPushedColumnPredicates.add(columnHandle); + return CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + inValues + ")"; } - else if (!rangeConjuncts.isEmpty()) { - predicate = Joiner.on(" AND ").join(rangeConjuncts); - } - return predicate; + return translateRangeIntoCql(columnHandle, ranges.getSpan()); }, discreteValues -> { if (discreteValues.isInclusive()) { - ImmutableList.Builder discreteValuesList = ImmutableList.builder(); - for (Object discreteValue : discreteValues.getValues()) { - discreteValuesList.add(columnHandle.getCassandraType().toCqlLiteral(discreteValue)); + if (discreteValues.getValuesCount() == 0) { + return null; + } + if (discreteValues.getValuesCount() == 1) { + fullyPushedColumnPredicates.add(columnHandle); + return format("%s = %s", + CassandraCqlUtils.validColumnName(columnHandle.getName()), + toCqlLiteral(columnHandle, getOnlyElement(discreteValues.getValues()))); } - String predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" - + Joiner.on(",").join(discreteValuesList.build()) + ")"; - return predicate; + if (isInExpressionNotAllowed(clusteringColumns, cassandraVersion, currentlyProcessedClusteringColumn)) { + return null; + } + + String inValues = discreteValues.getValues().stream() + .map(columnHandle.getCassandraType()::toCqlLiteral) + .collect(joining(",")); + fullyPushedColumnPredicates.add(columnHandle); + return CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + inValues + " )"; } return null; }, allOrNone -> null); @@ -130,37 +112,83 @@ else if (!rangeConjuncts.isEmpty()) { if (predicateString == null) { break; } - // IN restriction only on last clustering column for Cassandra version = 2.1 - if (predicateString.contains(" IN (") && cassandraVersion.compareTo(VersionNumber.parse("2.2.0")) < 0 && currentClusteringColumn != (clusteringColumns.size() - 1)) { - break; - } clusteringColumnSql.add(predicateString); - domainsBuilder.put(columnHandle, domain); // Check for last clustering column should only be restricted by range condition if (predicateString.contains(">") || predicateString.contains("<")) { break; } - currentClusteringColumn++; + allProcessedClusteringColumns++; } List clusteringColumnPredicates = clusteringColumnSql.build(); - return new ClusteringPushDownResult(domainsBuilder.build(), Joiner.on(" AND ").join(clusteringColumnPredicates)); + return new ClusteringPushDownResult(fullyPushedColumnPredicates.build(), Joiner.on(" AND ").join(clusteringColumnPredicates)); + } + + /** + * IN restriction allowed only on last clustering column for Cassandra version <= 2.2.0 + */ + private static boolean isInExpressionNotAllowed(List clusteringColumns, VersionNumber cassandraVersion, int currentlyProcessedClusteringColumn) + { + return cassandraVersion.compareTo(VersionNumber.parse("2.2.0")) < 0 && currentlyProcessedClusteringColumn != (clusteringColumns.size() - 1); + } + + private static String toCqlLiteral(CassandraColumnHandle columnHandle, Object value) + { + return columnHandle.getCassandraType().toCqlLiteral(value); + } + + private static String translateRangeIntoCql(CassandraColumnHandle columnHandle, Range range) + { + if (range.isAll()) { + return null; + } + if (range.isSingleValue()) { + return format("%s = %s", + CassandraCqlUtils.validColumnName(columnHandle.getName()), + toCqlLiteral(columnHandle, range.getSingleValue())); + } + + String lowerBoundPredicate = null; + String upperBoundPredicate = null; + if (!range.isLowUnbounded()) { + String lowBound = toCqlLiteral(columnHandle, range.getLowBoundedValue()); + lowerBoundPredicate = format( + "%s %s %s", + CassandraCqlUtils.validColumnName(columnHandle.getName()), + range.isLowInclusive() ? ">=" : ">", + lowBound); + } + if (!range.isHighUnbounded()) { + String highBound = toCqlLiteral(columnHandle, range.getHighBoundedValue()); + upperBoundPredicate = format( + "%s %s %s", + CassandraCqlUtils.validColumnName(columnHandle.getName()), + range.isHighInclusive() ? "<=" : "<", + highBound); + } + if (lowerBoundPredicate != null && upperBoundPredicate != null) { + return format("%s AND %s ", lowerBoundPredicate, upperBoundPredicate); + } + if (lowerBoundPredicate != null) { + return lowerBoundPredicate; + } + return upperBoundPredicate; } private static class ClusteringPushDownResult { - private final Map domains; + private final Set fullyPushedColumnPredicates; private final String domainQuery; - public ClusteringPushDownResult(Map domains, String domainQuery) + public ClusteringPushDownResult(Set fullyPushedColumnPredicates, String domainQuery) { - this.domains = requireNonNull(ImmutableMap.copyOf(domains)); + this.fullyPushedColumnPredicates = ImmutableSet.copyOf(requireNonNull(fullyPushedColumnPredicates, "fullyPushedColumnPredicates is null")); this.domainQuery = requireNonNull(domainQuery); } - public Map getDomains() + public boolean hasBeenFullyPushed(ColumnHandle column) { - return domains; + return fullyPushedColumnPredicates.contains(column); } public String getDomainQuery() diff --git a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java index 470a429e63666..070459f42799e 100644 --- a/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java +++ b/plugin/trino-cassandra/src/test/java/io/trino/plugin/cassandra/TestCassandraConnectorTest.java @@ -409,6 +409,19 @@ public void testClusteringKeyOnlyPushdown() assertEquals(execute(sql).getRowCount(), 1); } + @Test + public void testNotEqualPredicateOnClusteringColumn() + { + String sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one != 'clust_one'"; + assertEquals(execute(sql).getRowCount(), 0); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two != 2"; + assertEquals(execute(sql).getRowCount(), 3); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two >= 2 AND clust_two != 3"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two > 2 AND clust_two != 3"; + assertEquals(execute(sql).getRowCount(), 1); + } + @Test public void testClusteringKeyPushdownInequality() {