diff --git a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java index 1b2b25e5f4b5..1c0b3f9c5629 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java @@ -341,6 +341,21 @@ public QueryAssert matches(MaterializedResult expected) }); } + public final QueryAssert matches(PlanMatchPattern expectedPlan) + { + transaction(runner.getTransactionManager(), runner.getAccessControl()) + .execute(session, session -> { + Plan plan = runner.createPlan(session, query, WarningCollector.NOOP); + assertPlan( + session, + runner.getMetadata(), + noopStatsCalculator(), + plan, + expectedPlan); + }); + return this; + } + public QueryAssert containsAll(@Language("SQL") String query) { MaterializedResult expected = runner.execute(session, query); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java index e03d8dfea5b2..066bca5cfdf4 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveConnectorTest.java @@ -41,7 +41,6 @@ import io.trino.spi.type.VarcharType; import io.trino.sql.planner.Plan; import io.trino.sql.planner.plan.ExchangeNode; -import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.planprinter.IoPlanPrinter.ColumnConstraint; import io.trino.sql.planner.planprinter.IoPlanPrinter.EstimatedStatsAndCost; import io.trino.sql.planner.planprinter.IoPlanPrinter.FormattedDomain; @@ -7355,25 +7354,6 @@ public void testUseSortedProperties() assertUpdate("DROP TABLE " + tableName); } - private Consumer assertPartialLimitWithPreSortedInputsCount(Session session, int expectedCount) - { - return plan -> { - int actualCount = searchFrom(plan.getRoot()) - .where(node -> node instanceof LimitNode && ((LimitNode) node).isPartial() && ((LimitNode) node).requiresPreSortedInputs()) - .findAll() - .size(); - if (actualCount != expectedCount) { - Metadata metadata = getDistributedQueryRunner().getCoordinator().getMetadata(); - String formattedPlan = textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, StatsAndCosts.empty(), session, 0, false); - throw new AssertionError(format( - "Expected [\n%s\n] partial limit but found [\n%s\n] partial limit. Actual plan is [\n\n%s\n]", - expectedCount, - actualCount, - formattedPlan)); - } - }; - } - @Test public void testSelectWithNoColumns() { diff --git a/plugin/trino-phoenix5/pom.xml b/plugin/trino-phoenix5/pom.xml index edfcab397445..46cce6af5942 100644 --- a/plugin/trino-phoenix5/pom.xml +++ b/plugin/trino-phoenix5/pom.xml @@ -161,6 +161,19 @@ + + io.trino + trino-main + test-jar + test + + + + io.trino + trino-parser + test + + io.trino trino-testing diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java index b482d53f5b57..f64f16b34c37 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixMetadata.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.phoenix5; +import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.plugin.jdbc.DefaultJdbcMetadata; import io.trino.plugin.jdbc.JdbcColumnHandle; @@ -32,8 +33,12 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorTableSchema; +import io.trino.spi.connector.LocalProperty; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.connector.SortingProperty; +import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.TrinoPrincipal; import io.trino.spi.statistics.ComputedStatistics; @@ -86,6 +91,22 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, SchemaTableName .orElse(null); } + @Override + public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) + { + JdbcTableHandle tableHandle = (JdbcTableHandle) table; + List> sortingProperties = tableHandle.getSortOrder() + .map(properties -> properties + .stream() + .map(item -> (LocalProperty) new SortingProperty( + item.getColumn(), + item.getSortOrder())) + .collect(toImmutableList())) + .orElse(ImmutableList.of()); + + return new ConnectorTableProperties(TupleDomain.all(), Optional.empty(), Optional.empty(), Optional.empty(), sortingProperties); + } + @Override public ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorTableHandle table) { diff --git a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java index 9ef83453a778..335eb60ad81e 100644 --- a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java +++ b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java @@ -13,10 +13,12 @@ */ package io.trino.plugin.phoenix5; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; import io.trino.plugin.jdbc.UnsupportedTypeHandling; +import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.SqlExecutor; @@ -31,9 +33,26 @@ import java.sql.Statement; import java.util.List; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.UNSUPPORTED_TYPE_HANDLING; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; import static io.trino.plugin.phoenix5.PhoenixQueryRunner.createPhoenixQueryRunner; +import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; +import static io.trino.sql.planner.assertions.PlanMatchPattern.limit; +import static io.trino.sql.planner.assertions.PlanMatchPattern.output; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.trino.sql.planner.assertions.PlanMatchPattern.topN; +import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; +import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER; +import static io.trino.sql.planner.plan.TopNNode.Step.FINAL; +import static io.trino.sql.tree.SortItem.NullOrdering.FIRST; +import static io.trino.sql.tree.SortItem.NullOrdering.LAST; +import static io.trino.sql.tree.SortItem.Ordering.ASCENDING; +import static io.trino.sql.tree.SortItem.Ordering.DESCENDING; +import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertTrue; @@ -307,6 +326,118 @@ public void testMissingColumnsOnInsert() assertQuery("SELECT * FROM test_col_insert", "SELECT 1, 'val1', 'val2'"); } + @Override + public void testTopNPushdown() + { + throw new SkipException("Phoenix does not support topN push down, but instead replaces partial topN with partial Limit."); + } + + @Test + public void testReplacePartialTopNWithLimit() + { + List orderBy = ImmutableList.of(sort("orderkey", ASCENDING, LAST)); + + assertThat(query("SELECT orderkey FROM orders ORDER BY orderkey LIMIT 10")) + .matches(output( + topN(10, orderBy, FINAL, + exchange(LOCAL, GATHER, ImmutableList.of(), + exchange(REMOTE, GATHER, ImmutableList.of(), + limit( + 10, + ImmutableList.of(), + true, + orderBy.stream() + .map(PlanMatchPattern.Ordering::getField) + .collect(toImmutableList()), + tableScan("orders", ImmutableMap.of("orderkey", "orderkey")))))))); + + orderBy = ImmutableList.of(sort("orderkey", ASCENDING, FIRST)); + + assertThat(query("SELECT orderkey FROM orders ORDER BY orderkey NULLS FIRST LIMIT 10")) + .matches(output( + topN(10, orderBy, FINAL, + exchange(LOCAL, GATHER, ImmutableList.of(), + exchange(REMOTE, GATHER, ImmutableList.of(), + limit( + 10, + ImmutableList.of(), + true, + orderBy.stream() + .map(PlanMatchPattern.Ordering::getField) + .collect(toImmutableList()), + tableScan("orders", ImmutableMap.of("orderkey", "orderkey")))))))); + + orderBy = ImmutableList.of(sort("orderkey", DESCENDING, LAST)); + + assertThat(query("SELECT orderkey FROM orders ORDER BY orderkey DESC LIMIT 10")) + .matches(output( + topN(10, orderBy, FINAL, + exchange(LOCAL, GATHER, ImmutableList.of(), + exchange(REMOTE, GATHER, ImmutableList.of(), + limit( + 10, + ImmutableList.of(), + true, + orderBy.stream() + .map(PlanMatchPattern.Ordering::getField) + .collect(toImmutableList()), + tableScan("orders", ImmutableMap.of("orderkey", "orderkey")))))))); + + orderBy = ImmutableList.of(sort("orderkey", ASCENDING, LAST), sort("custkey", ASCENDING, LAST)); + + assertThat(query("SELECT orderkey FROM orders ORDER BY orderkey, custkey LIMIT 10")) + .matches(output( + project( + topN(10, orderBy, FINAL, + exchange(LOCAL, GATHER, ImmutableList.of(), + exchange(REMOTE, GATHER, ImmutableList.of(), + limit( + 10, + ImmutableList.of(), + true, + orderBy.stream() + .map(PlanMatchPattern.Ordering::getField) + .collect(toImmutableList()), + tableScan("orders", ImmutableMap.of("orderkey", "orderkey", "custkey", "custkey"))))))))); + + orderBy = ImmutableList.of(sort("orderkey", ASCENDING, LAST), sort("custkey", DESCENDING, LAST)); + + assertThat(query("SELECT orderkey FROM orders ORDER BY orderkey, custkey DESC LIMIT 10")) + .matches(output( + project( + topN(10, orderBy, FINAL, + exchange(LOCAL, GATHER, ImmutableList.of(), + exchange(REMOTE, GATHER, ImmutableList.of(), + limit( + 10, + ImmutableList.of(), + true, + orderBy.stream() + .map(PlanMatchPattern.Ordering::getField) + .collect(toImmutableList()), + tableScan("orders", ImmutableMap.of("orderkey", "orderkey", "custkey", "custkey"))))))))); + } + + /* + * Make sure that partial topN is replaced with a partial limit when the input is presorted. + */ + @Test + public void testUseSortedPropertiesForPartialTopNElimination() + { + String tableName = "test_propagate_table_scan_sorting_properties"; + // salting ensures multiple splits + String createTableSql = format("" + + "CREATE TABLE %s WITH (salt_buckets = 5) AS " + + "SELECT * FROM tpch.tiny.customer", + tableName); + assertUpdate(createTableSql, 1500L); + + String expected = "SELECT custkey FROM customer ORDER BY 1 NULLS FIRST LIMIT 100"; + String actual = format("SELECT custkey FROM %s ORDER BY 1 NULLS FIRST LIMIT 100", tableName); + assertQuery(getSession(), actual, expected, assertPartialLimitWithPreSortedInputsCount(getSession(), 1)); + assertUpdate("DROP TABLE " + tableName); + } + @Override protected TestTable createTableWithDoubleAndRealColumns(String name, List rows) { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index 22336523de4e..6565b02997b5 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -14,17 +14,24 @@ package io.trino.testing; import io.trino.Session; +import io.trino.cost.StatsAndCosts; +import io.trino.metadata.Metadata; import io.trino.sql.analyzer.FeaturesConfig.JoinDistributionType; +import io.trino.sql.planner.Plan; +import io.trino.sql.planner.plan.LimitNode; import io.trino.testing.sql.TestTable; import org.intellij.lang.annotations.Language; import org.testng.SkipException; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.util.function.Consumer; import java.util.stream.Stream; import static io.trino.SystemSessionProperties.IGNORE_STATS_CALCULATOR_FAILURES; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static io.trino.sql.planner.planprinter.PlanPrinter.textLogicalPlan; import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.QueryAssertions.assertContains; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_ARRAY; @@ -42,6 +49,7 @@ import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_TOPN_PUSHDOWN; import static io.trino.testing.assertions.Assert.assertEquals; import static io.trino.testing.sql.TestTable.randomTableSuffix; +import static java.lang.String.format; import static java.lang.String.join; import static java.util.Collections.nCopies; import static org.assertj.core.api.Assertions.assertThat; @@ -776,4 +784,23 @@ public void testRowLevelDelete() assertQuery("SELECT count(*) FROM " + table.getName(), "VALUES 4"); } } + + protected Consumer assertPartialLimitWithPreSortedInputsCount(Session session, int expectedCount) + { + return plan -> { + int actualCount = searchFrom(plan.getRoot()) + .where(node -> node instanceof LimitNode && ((LimitNode) node).isPartial() && ((LimitNode) node).requiresPreSortedInputs()) + .findAll() + .size(); + if (actualCount != expectedCount) { + Metadata metadata = getDistributedQueryRunner().getCoordinator().getMetadata(); + String formattedPlan = textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, StatsAndCosts.empty(), session, 0, false); + throw new AssertionError(format( + "Expected [\n%s\n] partial limit but found [\n%s\n] partial limit. Actual plan is [\n\n%s\n]", + expectedCount, + actualCount, + formattedPlan)); + } + }; + } }