diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 7fe07ced72337..d9446bb3c24ea 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -122,6 +122,7 @@ import io.trino.sql.planner.iterative.rule.PruneSortColumns; import io.trino.sql.planner.iterative.rule.PruneSpatialJoinChildrenColumns; import io.trino.sql.planner.iterative.rule.PruneSpatialJoinColumns; +import io.trino.sql.planner.iterative.rule.PruneTableExecuteSourceColumns; import io.trino.sql.planner.iterative.rule.PruneTableScanColumns; import io.trino.sql.planner.iterative.rule.PruneTableWriterSourceColumns; import io.trino.sql.planner.iterative.rule.PruneTopNColumns; @@ -356,6 +357,7 @@ public PlanOptimizers( new PruneSortColumns(), new PruneSpatialJoinChildrenColumns(), new PruneSpatialJoinColumns(), + new PruneTableExecuteSourceColumns(), new PruneTableScanColumns(metadata), new PruneTableWriterSourceColumns(), new PruneTopNColumns(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneTableExecuteSourceColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneTableExecuteSourceColumns.java new file mode 100644 index 0000000000000..df535be17c06c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneTableExecuteSourceColumns.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableSet; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.sql.planner.PartitioningScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.TableExecuteNode; + +import static io.trino.sql.planner.iterative.rule.Util.restrictChildOutputs; +import static io.trino.sql.planner.plan.Patterns.tableExecute; + +public class PruneTableExecuteSourceColumns + implements Rule +{ + private static final Pattern PATTERN = tableExecute(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableExecuteNode tableExecuteNode, Captures captures, Context context) + { + ImmutableSet.Builder requiredInputs = ImmutableSet.builder() + .addAll(tableExecuteNode.getColumns()); + + if (tableExecuteNode.getPartitioningScheme().isPresent()) { + PartitioningScheme partitioningScheme = tableExecuteNode.getPartitioningScheme().get(); + partitioningScheme.getPartitioning().getColumns().forEach(requiredInputs::add); + partitioningScheme.getHashColumn().ifPresent(requiredInputs::add); + } + + return restrictChildOutputs(context.getIdAllocator(), tableExecuteNode, requiredInputs.build()) + .map(Result::ofPlanNode) + .orElse(Result.empty()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingTableExecuteHandle.java b/core/trino-main/src/main/java/io/trino/testing/TestingTableExecuteHandle.java new file mode 100644 index 0000000000000..fcd3b7703fcd5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/testing/TestingTableExecuteHandle.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.testing; + +import io.trino.spi.connector.ConnectorTableExecuteHandle; + +public class TestingTableExecuteHandle + implements ConnectorTableExecuteHandle {} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 7e03562582cd3..0375a3eb3ce53 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -55,6 +55,7 @@ import io.trino.sql.planner.plan.SemiJoinNode; import io.trino.sql.planner.plan.SortNode; import io.trino.sql.planner.plan.SpatialJoinNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -854,6 +855,11 @@ public static PlanMatchPattern tableWriter(List columns, List co return node(TableWriterNode.class, source).with(new TableWriterMatcher(columns, columnNames)); } + public static PlanMatchPattern tableExecute(List columns, List columnNames, PlanMatchPattern source) + { + return node(TableExecuteNode.class, source).with(new TableExecuteMatcher(columns, columnNames)); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableExecuteMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableExecuteMatcher.java new file mode 100644 index 0000000000000..4f1eec03486d2 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableExecuteMatcher.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.sql.planner.assertions; + +import io.trino.Session; +import io.trino.cost.StatsProvider; +import io.trino.metadata.Metadata; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.TableExecuteNode; + +import java.util.List; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; +import static io.trino.sql.planner.assertions.MatchResult.match; + +public class TableExecuteMatcher + implements Matcher +{ + private final List columns; + private final List columnNames; + + public TableExecuteMatcher(List columns, List columnNames) + { + this.columns = columns; + this.columnNames = columnNames; + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableExecuteNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableExecuteNode tableExecuteNode = (TableExecuteNode) node; + if (!tableExecuteNode.getColumnNames().equals(columnNames)) { + return NO_MATCH; + } + + if (!columns.stream() + .map(symbol -> Symbol.from(symbolAliases.get(symbol))) + .collect(toImmutableList()) + .equals(tableExecuteNode.getColumns())) { + return NO_MATCH; + } + + return match(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columns", columns) + .add("columnNames", columnNames) + .toString(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableExecuteSourceColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableExecuteSourceColumns.java new file mode 100644 index 0000000000000..cb10686aaf433 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneTableExecuteSourceColumns.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; +import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableExecute; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableExecuteSourceColumns + extends BaseRuleTest +{ + @Test + public void testNotAllInputsReferenced() + { + tester().assertThat(new PruneTableExecuteSourceColumns()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.tableExecute( + ImmutableList.of(a), + ImmutableList.of("column_a"), + p.values(a, b)); + }) + .matches( + tableExecute( + ImmutableList.of("a"), + ImmutableList.of("column_a"), + strictProject( + ImmutableMap.of("a", expression("a")), + values("a", "b")))); + } + + @Test + public void testAllInputsReferenced() + { + tester().assertThat(new PruneTableExecuteSourceColumns()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + return p.tableExecute( + ImmutableList.of(a, b), + ImmutableList.of("column_a", "column_b"), + p.values(a, b)); + }) + .doesNotFire(); + } + + @Test + public void testDoNotPrunePartitioningSchemeSymbols() + { + tester().assertThat(new PruneTableExecuteSourceColumns()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol partition = p.symbol("partition"); + Symbol hash = p.symbol("hash"); + return p.tableExecute( + ImmutableList.of(a), + ImmutableList.of("column_a"), + Optional.of(p.partitioningScheme( + ImmutableList.of(partition, hash), + ImmutableList.of(partition), + hash)), + Optional.empty(), + p.values(a, partition, hash)); + }) + .doesNotFire(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index af1f068f62789..8f0c6bd7270bf 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -25,6 +25,7 @@ import io.trino.metadata.IndexHandle; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TableExecuteHandle; import io.trino.metadata.TableHandle; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.SchemaTableName; @@ -79,6 +80,7 @@ import io.trino.sql.planner.plan.SpatialJoinNode; import io.trino.sql.planner.plan.StatisticAggregations; import io.trino.sql.planner.plan.StatisticAggregationsDescriptor; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -100,6 +102,7 @@ import io.trino.testing.TestingHandle; import io.trino.testing.TestingMetadata.TestingColumnHandle; import io.trino.testing.TestingMetadata.TestingTableHandle; +import io.trino.testing.TestingTableExecuteHandle; import io.trino.testing.TestingTransactionHandle; import java.util.ArrayList; @@ -1128,6 +1131,36 @@ public TableWriterNode tableWriter( statisticAggregationsDescriptor); } + public TableExecuteNode tableExecute(List columns, List columnNames, PlanNode source) + { + return tableExecute(columns, columnNames, Optional.empty(), Optional.empty(), source); + } + + public TableExecuteNode tableExecute( + List columns, + List columnNames, + Optional partitioningScheme, + Optional preferredPartitioningScheme, + PlanNode source) + { + return new TableExecuteNode( + idAllocator.getNextId(), + source, + new TableWriterNode.TableExecuteTarget( + new TableExecuteHandle( + new CatalogName("testConnector"), + TestingTransactionHandle.create(), + new TestingTableExecuteHandle()), + Optional.empty(), + SchemaTableName.schemaTableName("testschema", "testtable")), + symbol("partialrows", BIGINT), + symbol("fragment", VARBINARY), + columns, + columnNames, + partitioningScheme, + preferredPartitioningScheme); + } + public PartitioningScheme partitioningScheme(List outputSymbols, List partitioningSymbols, Symbol hashSymbol) { return new PartitioningScheme(Partitioning.create(