diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcMetadata.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcMetadata.java index bc96bac745dea..e384d48555d67 100644 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcMetadata.java +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcMetadata.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.prestosql.spi.PrestoException; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.ConnectorInsertTableHandle; @@ -33,7 +34,6 @@ import io.prestosql.spi.connector.ConstraintApplicationResult; import io.prestosql.spi.connector.LimitApplicationResult; import io.prestosql.spi.connector.ProjectionApplicationResult; -import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; import io.prestosql.spi.connector.SystemTable; diff --git a/presto-bigquery/src/main/java/io/prestosql/plugin/bigquery/BigQueryMetadata.java b/presto-bigquery/src/main/java/io/prestosql/plugin/bigquery/BigQueryMetadata.java index 9f938122e43e2..c87027ce6a540 100644 --- a/presto-bigquery/src/main/java/io/prestosql/plugin/bigquery/BigQueryMetadata.java +++ b/presto-bigquery/src/main/java/io/prestosql/plugin/bigquery/BigQueryMetadata.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import io.airlift.log.Logger; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.ConnectorMetadata; @@ -37,7 +38,6 @@ import io.prestosql.spi.connector.LimitApplicationResult; import io.prestosql.spi.connector.NotFoundException; 
import io.prestosql.spi.connector.ProjectionApplicationResult; -import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; import io.prestosql.spi.connector.TableNotFoundException; diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java index 23c3212e698ef..4e544b249ac30 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java @@ -49,6 +49,7 @@ import io.prestosql.spi.PrestoException; import io.prestosql.spi.StandardErrorCode; import io.prestosql.spi.block.Block; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.CatalogSchemaName; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; @@ -68,7 +69,6 @@ import io.prestosql.spi.connector.DiscretePredicates; import io.prestosql.spi.connector.InMemoryRecordSet; import io.prestosql.spi.connector.ProjectionApplicationResult; -import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment; import io.prestosql.spi.connector.SchemaNotFoundException; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java index 69e39b2a3e5a6..a0303658ac7bc 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java @@ -60,6 +60,7 @@ import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import 
io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.ConnectorInsertTableHandle; @@ -84,7 +85,6 @@ import io.prestosql.spi.connector.ConstraintApplicationResult; import io.prestosql.spi.connector.DiscretePredicates; import io.prestosql.spi.connector.ProjectionApplicationResult; -import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment; import io.prestosql.spi.connector.RecordCursor; import io.prestosql.spi.connector.RecordPageSource; import io.prestosql.spi.connector.SchemaTableName; @@ -3057,7 +3057,7 @@ public void testApplyProjection() assertProjectionResult(projectionResult, false, expectedProjections, expectedAssignments); // Round-2: input projections [symbol_2.int0 and onelevelrow0#f_int0]. Virtual handle is reused. - ProjectionApplicationResult.Assignment newlyCreatedColumn = getOnlyElement(projectionResult.get().getAssignments().stream() + Assignment newlyCreatedColumn = getOnlyElement(projectionResult.get().getAssignments().stream() .filter(handle -> handle.getVariable().equals("onelevelrow0#f_int0")) .collect(toList())); inputAssignments = ImmutableMap.builder() diff --git a/presto-kudu/src/main/java/io/prestosql/plugin/kudu/KuduMetadata.java b/presto-kudu/src/main/java/io/prestosql/plugin/kudu/KuduMetadata.java index b50277e6b0c4c..0fa95aff34fc5 100755 --- a/presto-kudu/src/main/java/io/prestosql/plugin/kudu/KuduMetadata.java +++ b/presto-kudu/src/main/java/io/prestosql/plugin/kudu/KuduMetadata.java @@ -18,6 +18,7 @@ import io.airlift.slice.Slice; import io.prestosql.plugin.kudu.properties.KuduTableProperties; import io.prestosql.plugin.kudu.properties.PartitionDesign; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.ConnectorInsertTableHandle; @@ -33,7 +34,6 @@ import io.prestosql.spi.connector.ConstraintApplicationResult; import io.prestosql.spi.connector.NotFoundException; import 
io.prestosql.spi.connector.ProjectionApplicationResult; -import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; import io.prestosql.spi.expression.ConnectorExpression; diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java index c064dc38d749b..2241a084ef62f 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java +++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java @@ -21,6 +21,8 @@ import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.BlockEncoding; import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.connector.AggregateFunction; +import io.prestosql.spi.connector.AggregationApplicationResult; import io.prestosql.spi.connector.CatalogSchemaName; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; @@ -353,6 +355,13 @@ public interface Metadata Optional applySample(Session session, TableHandle table, SampleType sampleType, double sampleRatio); + Optional> applyAggregation( + Session session, + TableHandle table, + List aggregations, + Map assignments, + List> groupingSets); + default void validateScan(Session session, TableHandle table) {} // diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 53719bff55442..f12a45082baba 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -44,6 +44,9 @@ import io.prestosql.spi.block.SingleMapBlockEncoding; import io.prestosql.spi.block.SingleRowBlockEncoding; import io.prestosql.spi.block.VariableWidthBlockEncoding; +import io.prestosql.spi.connector.AggregateFunction; +import 
io.prestosql.spi.connector.AggregationApplicationResult; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.CatalogSchemaName; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; @@ -1108,6 +1111,57 @@ public Optional applySample(Session session, TableHandle table, Sam Optional.empty())); } + @Override + public Optional> applyAggregation( + Session session, + TableHandle table, + List aggregations, + Map assignments, + List> groupingSets) + { + CatalogName catalogName = table.getCatalogName(); + ConnectorMetadata metadata = getMetadata(session, catalogName); + + if (metadata.usesLegacyTableLayouts()) { + return Optional.empty(); + } + + ConnectorSession connectorSession = session.toConnectorSession(catalogName); + return metadata.applyAggregation(connectorSession, table.getConnectorHandle(), aggregations, assignments, groupingSets) + .map(result -> { + verifyProjection(table, result.getProjections(), result.getAssignments(), aggregations.size()); + + return new AggregationApplicationResult<>( + new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()), + result.getProjections(), + result.getAssignments(), + result.getGroupingColumnMapping()); + }); + } + + private void verifyProjection(TableHandle table, List projections, List assignments, int expectedProjectionSize) + { + projections.forEach(projection -> requireNonNull(projection, "one of the projections is null")); + assignments.forEach(assignment -> requireNonNull(assignment, "one of the assignments is null")); + + verify( + expectedProjectionSize == projections.size(), + "ConnectorMetadata returned invalid number of projections: %s instead of %s for %s", + projections.size(), + expectedProjectionSize, + table); + + Set assignedVariables = assignments.stream() + .map(Assignment::getVariable) + .collect(toImmutableSet()); + projections.stream() + .flatMap(connectorExpression -> 
ConnectorExpressions.extractVariables(connectorExpression).stream()) + .map(Variable::getName) + .filter(variableName -> !assignedVariables.contains(variableName)) + .findAny() + .ifPresent(variableName -> { throw new IllegalStateException("Unbound variable: " + variableName); }); + } + @Override public void validateScan(Session session, TableHandle table) { @@ -1146,24 +1200,7 @@ public Optional> applyProjection(Sessio ConnectorSession connectorSession = session.toConnectorSession(catalogName); return metadata.applyProjection(connectorSession, table.getConnectorHandle(), projections, assignments) .map(result -> { - result.getProjections().forEach(projection -> requireNonNull(projection, "one of the projections is null")); - result.getAssignments().forEach(assignment -> requireNonNull(assignment, "one of the assignments is null")); - - verify( - projections.size() == result.getProjections().size(), - "ConnectorMetadata returned invalid number of projections: %s instead of %s for %s", - result.getProjections().size(), - projections.size(), - table); - - Set assignedVariables = result.getAssignments().stream() - .map(ProjectionApplicationResult.Assignment::getVariable) - .collect(toImmutableSet()); - result.getProjections().stream() - .flatMap(connectorExpression -> ConnectorExpressions.extractVariables(connectorExpression).stream()) - .map(Variable::getName) - .filter(variableName -> !assignedVariables.contains(variableName)) - .findAny().ifPresent(variableName -> { throw new IllegalStateException("Unbound variable: " + variableName); }); + verifyProjection(table, result.getProjections(), result.getAssignments(), projections.size()); return new ProjectionApplicationResult<>( new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()), diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 98fc3de02dfbf..c5b33a2998e58 100644 --- 
a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -106,6 +106,7 @@ import io.prestosql.sql.planner.iterative.rule.PruneUnnestSourceColumns; import io.prestosql.sql.planner.iterative.rule.PruneValuesColumns; import io.prestosql.sql.planner.iterative.rule.PruneWindowColumns; +import io.prestosql.sql.planner.iterative.rule.PushAggregationIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; import io.prestosql.sql.planner.iterative.rule.PushDeleteIntoConnector; import io.prestosql.sql.planner.iterative.rule.PushDownDereferenceThroughFilter; @@ -508,6 +509,7 @@ public PlanOptimizers( .add(new PushLimitIntoTableScan(metadata)) .add(new PushPredicateIntoTableScan(metadata, typeAnalyzer)) .add(new PushSampleIntoTableScan(metadata)) + .add(new PushAggregationIntoTableScan(metadata)) .build()), new IterativeOptimizer( ruleStats, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java index 0f29868ea23f1..4ddfdf9fcbe8a 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PruneTableScanColumns.java @@ -17,6 +17,7 @@ import io.prestosql.Session; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ProjectionApplicationResult; import io.prestosql.spi.expression.ConnectorExpression; @@ -88,7 +89,7 @@ public static Optional pruneColumns(Metadata metadata, TypeProvider ty handle = result.get().getHandle(); Map assignments = result.get().getAssignments().stream() - 
.collect(toImmutableMap(ProjectionApplicationResult.Assignment::getVariable, ProjectionApplicationResult.Assignment::getColumn)); + .collect(toImmutableMap(Assignment::getVariable, Assignment::getColumn)); ImmutableMap.Builder builder = ImmutableMap.builder(); for (int i = 0; i < newOutputs.size(); i++) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushAggregationIntoTableScan.java new file mode 100644 index 0000000000000..0f8a4bc45d127 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.Signature; +import io.prestosql.metadata.TableHandle; +import io.prestosql.spi.connector.AggregateFunction; +import io.prestosql.spi.connector.AggregationApplicationResult; +import io.prestosql.spi.connector.Assignment; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.SortItem; +import io.prestosql.spi.connector.SortOrder; +import io.prestosql.spi.expression.ConnectorExpression; +import io.prestosql.spi.expression.Variable; +import io.prestosql.sql.planner.ConnectorExpressionTranslator; +import io.prestosql.sql.planner.LiteralEncoder; +import io.prestosql.sql.planner.OrderingScheme; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.AggregationNode; +import io.prestosql.sql.planner.plan.AggregationNode.GroupingSetDescriptor; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.stream.IntStream; + +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.plan.Patterns.aggregation; +import static io.prestosql.sql.planner.plan.Patterns.source; +import 
static io.prestosql.sql.planner.plan.Patterns.tableScan; + +public class PushAggregationIntoTableScan + implements Rule +{ + private static final Capture TABLE_SCAN = newCapture(); + + private static final Pattern PATTERN = + aggregation() + // skip arguments that are, for instance, lambda expressions + .matching(PushAggregationIntoTableScan::allArgumentsAreSimpleReferences) + .matching(node -> node.getGroupingSets().getGroupingSetCount() <= 1) + .matching(PushAggregationIntoTableScan::hasNoMasks) + .with(source().matching(tableScan().capturedAs(TABLE_SCAN))); + + private final Metadata metadata; + + public PushAggregationIntoTableScan(Metadata metadata) + { + this.metadata = metadata; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + private static boolean allArgumentsAreSimpleReferences(AggregationNode node) + { + return node.getAggregations() + .values().stream() + .flatMap(aggregation -> aggregation.getArguments().stream()) + .allMatch(SymbolReference.class::isInstance); + } + + private static boolean hasNoMasks(AggregationNode node) + { + return !node.getAggregations() + .values().stream() + .map(aggregation -> aggregation.getMask().isPresent()) + .anyMatch(isMaskPresent -> isMaskPresent); + } + + @Override + public Result apply(AggregationNode node, Captures captures, Context context) + { + TableScanNode tableScan = captures.get(TABLE_SCAN); + Map assignments = tableScan.getAssignments() + .entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue)); + + List> aggregations = node.getAggregations() + .entrySet().stream() + .collect(toImmutableList()); + + List aggregateFunctions = aggregations.stream() + .map(Entry::getValue) + .map(aggregation -> toAggregateFunction(context, aggregation)) + .collect(toImmutableList()); + + List aggregationOutputSymbols = aggregations.stream() + .map(Entry::getKey) + .collect(toImmutableList()); + + GroupingSetDescriptor groupingSets = node.getGroupingSets(); + 
+ List groupByColumns = groupingSets.getGroupingKeys().stream() + .map(groupByColumn -> assignments.get(groupByColumn.getName())) + .collect(toImmutableList()); + + Optional> aggregationPushdownResult = metadata.applyAggregation( + context.getSession(), + tableScan.getTable(), + aggregateFunctions, + assignments, + ImmutableList.of(groupByColumns)); + + if (aggregationPushdownResult.isEmpty()) { + return Result.empty(); + } + + AggregationApplicationResult result = aggregationPushdownResult.get(); + + // The new scan outputs should be the symbols associated with grouping columns plus the symbols associated with aggregations. + ImmutableList.Builder newScanOutputs = new ImmutableList.Builder<>(); + newScanOutputs.addAll(tableScan.getOutputSymbols()); + + ImmutableBiMap.Builder newScanAssignments = new ImmutableBiMap.Builder<>(); + newScanAssignments.putAll(tableScan.getAssignments()); + + Map variableMappings = new HashMap<>(); + + for (Assignment assignment : result.getAssignments()) { + Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType()); + + newScanOutputs.add(symbol); + newScanAssignments.put(symbol, assignment.getColumn()); + variableMappings.put(assignment.getVariable(), symbol); + } + + List newProjections = result.getProjections().stream() + .map(expression -> ConnectorExpressionTranslator.translate(expression, variableMappings, new LiteralEncoder(metadata))) + .collect(toImmutableList()); + + verify(aggregationOutputSymbols.size() == newProjections.size()); + + Assignments.Builder assignmentBuilder = Assignments.builder(); + IntStream.range(0, aggregationOutputSymbols.size()) + .forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index))); + + ImmutableBiMap scanAssignments = newScanAssignments.build(); + ImmutableBiMap columnHandleToSymbol = scanAssignments.inverse(); + // projections assignmentBuilder should have both agg and group by so we add all the group bys 
as symbol references + groupingSets.getGroupingKeys() + .forEach(groupBySymbol -> { + // if the connector returned a new mapping from oldColumnHandle to newColumnHandle, groupBy needs to point to + // new columnHandle's symbol reference, otherwise it will continue pointing at oldColumnHandle. + ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName()); + ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle); + assignmentBuilder.put(groupBySymbol, columnHandleToSymbol.get(groupByColumnHandle).toSymbolReference()); + }); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + TableScanNode.newInstance( + context.getIdAllocator().getNextId(), + result.getHandle(), + newScanOutputs.build(), + scanAssignments), + assignmentBuilder.build())); + } + + private AggregateFunction toAggregateFunction(Context context, AggregationNode.Aggregation aggregation) + { + Signature signature = aggregation.getResolvedFunction().getSignature(); + + ImmutableList.Builder arguments = new ImmutableList.Builder<>(); + for (int i = 0; i < aggregation.getArguments().size(); i++) { + SymbolReference argument = (SymbolReference) aggregation.getArguments().get(i); + arguments.add(new Variable(argument.getName(), metadata.getType(signature.getArgumentTypes().get(i)))); + } + + Optional orderingScheme = aggregation.getOrderingScheme(); + Optional> sortBy = orderingScheme.map(orderings -> + orderings.getOrderBy().stream() + .map(orderBy -> new SortItem( + orderBy.getName(), + SortOrder.valueOf(orderings.getOrderings().get(orderBy).name()))) + .collect(toImmutableList())); + + Optional filter = aggregation.getFilter() + .map(symbol -> new Variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol))); + + return new AggregateFunction( + signature.getName(), + metadata.getType(signature.getReturnType()), + arguments.build(), + 
sortBy.orElse(ImmutableList.of()), + aggregation.isDistinct(), + filter); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionIntoTableScan.java index 54babc519efd7..31add6c71d849 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionIntoTableScan.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionIntoTableScan.java @@ -20,6 +20,7 @@ import io.prestosql.matching.Pattern; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ProjectionApplicationResult; import io.prestosql.spi.expression.ConnectorExpression; @@ -119,7 +120,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) List newScanOutputs = new ArrayList<>(); Map newScanAssignments = new HashMap<>(); Map variableMappings = new HashMap<>(); - for (ProjectionApplicationResult.Assignment assignment : result.get().getAssignments()) { + for (Assignment assignment : result.get().getAssignments()) { Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType()); newScanOutputs.add(symbol); diff --git a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java index d4b7f17ff3500..bd686f74d1985 100644 --- a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java +++ b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java @@ -23,6 +23,8 @@ import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.BlockEncoding; import io.prestosql.spi.block.BlockEncodingSerde; +import io.prestosql.spi.connector.AggregateFunction; +import 
io.prestosql.spi.connector.AggregationApplicationResult; import io.prestosql.spi.connector.CatalogSchemaName; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; @@ -447,6 +449,17 @@ public Optional applySample(Session session, TableHandle table, Sam return Optional.empty(); } + @Override + public Optional> applyAggregation( + Session session, + TableHandle table, + List aggregations, + Map assignments, + List> groupingSets) + { + return Optional.empty(); + } + // // Roles and Grants // diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java index a63bd6332c131..a761b0a355718 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionIntoTableScan.java @@ -22,13 +22,13 @@ import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; import io.prestosql.plugin.tpch.TpchColumnHandle; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.ConnectorTableHandle; import io.prestosql.spi.connector.ConnectorTransactionHandle; import io.prestosql.spi.connector.ProjectionApplicationResult; -import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.spi.expression.Constant; diff --git a/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java 
b/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 4c39dfebc4584..3a1eb0a5e5974 100644 --- a/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -15,6 +15,8 @@ import io.airlift.slice.Slice; import io.prestosql.spi.classloader.ThreadContextClassLoader; +import io.prestosql.spi.connector.AggregateFunction; +import io.prestosql.spi.connector.AggregationApplicationResult; import io.prestosql.spi.connector.CatalogSchemaName; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; @@ -695,6 +697,19 @@ public Optional applySample(ConnectorSession session, Conn } } + @Override + public Optional> applyAggregation( + ConnectorSession session, + ConnectorTableHandle table, + List aggregates, + Map assignments, + List> groupingSets) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.applyAggregation(session, table, aggregates, assignments, groupingSets); + } + } + @Override public void validateScan(ConnectorSession session, ConnectorTableHandle handle) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/AggregateFunction.java b/presto-spi/src/main/java/io/prestosql/spi/connector/AggregateFunction.java new file mode 100644 index 0000000000000..eca20e8c0c547 --- /dev/null +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/AggregateFunction.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.connector; + +import io.prestosql.spi.expression.ConnectorExpression; +import io.prestosql.spi.type.Type; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.StringJoiner; + +import static java.util.Objects.requireNonNull; + +public class AggregateFunction +{ + private final String functionName; + private final Type outputType; + private final List inputs; + private final List sortItems; + private final boolean isDistinct; + private final Optional filter; + + public AggregateFunction( + String aggregateFunctionName, + Type outputType, + List inputs, + List sortItems, + boolean isDistinct, + Optional filter) + { + this.functionName = requireNonNull(aggregateFunctionName, "name is null"); + this.outputType = requireNonNull(outputType, "outputType is null"); + requireNonNull(inputs, "inputs is null"); + requireNonNull(sortItems, "sortOrder is null"); + this.inputs = List.copyOf(inputs); + this.sortItems = List.copyOf(sortItems); + this.isDistinct = isDistinct; + this.filter = requireNonNull(filter, "filter is null"); + } + + public String getFunctionName() + { + return functionName; + } + + public List getInputs() + { + return inputs; + } + + public Type getOutputType() + { + return outputType; + } + + public List getSortItems() + { + return sortItems; + } + + public boolean isDistinct() + { + return isDistinct; + } + + public Optional getFilter() + { + return filter; + } + + @Override + public String toString() + { + return new StringJoiner(", ", 
AggregateFunction.class.getSimpleName() + "[", "]") + .add("aggregationName='" + functionName + "'") + .add("inputs=" + inputs) + .add("outputType=" + outputType) + .add("sortOrder=" + sortItems) + .add("isDistinct=" + isDistinct) + .add("filter=" + filter) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + AggregateFunction that = (AggregateFunction) o; + return isDistinct == that.isDistinct && + Objects.equals(functionName, that.functionName) && + Objects.equals(inputs, that.inputs) && + Objects.equals(outputType, that.outputType) && + Objects.equals(sortItems, that.sortItems) && + Objects.equals(filter, that.filter); + } + + @Override + public int hashCode() + { + return Objects.hash(functionName, inputs, outputType, sortItems, isDistinct, filter); + } +} diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/AggregationApplicationResult.java b/presto-spi/src/main/java/io/prestosql/spi/connector/AggregationApplicationResult.java new file mode 100644 index 0000000000000..943361e42b51c --- /dev/null +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/AggregationApplicationResult.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.spi.connector; + +import io.prestosql.spi.expression.ConnectorExpression; + +import java.util.List; +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class AggregationApplicationResult +{ + private final T handle; + private final List projections; + private final List assignments; + private final Map groupingColumnMapping; + + public AggregationApplicationResult( + T handle, + List projections, + List assignments, + Map groupingColumnMapping) + { + this.handle = requireNonNull(handle, "handle is null"); + requireNonNull(groupingColumnMapping, "goupingSetMapping is null"); + requireNonNull(projections, "projections is null"); + requireNonNull(assignments, "assignment is null"); + this.groupingColumnMapping = Map.copyOf(groupingColumnMapping); + this.projections = List.copyOf(projections); + this.assignments = List.copyOf(assignments); + } + + public T getHandle() + { + return handle; + } + + public List getProjections() + { + return projections; + } + + public List getAssignments() + { + return assignments; + } + + public Map getGroupingColumnMapping() + { + return groupingColumnMapping; + } +} diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/Assignment.java b/presto-spi/src/main/java/io/prestosql/spi/connector/Assignment.java new file mode 100644 index 0000000000000..7a497beab9bbd --- /dev/null +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/Assignment.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.connector; + +import io.prestosql.spi.type.Type; + +import static java.util.Objects.requireNonNull; + +public class Assignment +{ + private final String variable; + private final ColumnHandle column; + private final Type type; + + public Assignment(String variable, ColumnHandle column, Type type) + { + this.variable = requireNonNull(variable, "variable is null"); + this.column = requireNonNull(column, "column is null"); + this.type = requireNonNull(type, "type is null"); + } + + public String getVariable() + { + return variable; + } + + public ColumnHandle getColumn() + { + return column; + } + + public Type getType() + { + return type; + } +} diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java index 0fac96b45821b..8decc2a4bd64e 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java @@ -16,6 +16,8 @@ import io.airlift.slice.Slice; import io.prestosql.spi.PrestoException; import io.prestosql.spi.expression.ConnectorExpression; +import io.prestosql.spi.expression.Constant; +import io.prestosql.spi.expression.Variable; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.security.GrantInfo; import io.prestosql.spi.security.PrestoPrincipal; @@ -840,6 +842,89 @@ default Optional applySample(ConnectorSession session, Con return Optional.empty(); } + /** + * Attempt to push down the aggregates into the table. + *

+ * Connectors indicate that they do not support aggregate pushdown, or that the action had no effect, + * by returning {@link Optional#empty()}. Connectors should expect that this method may be called multiple times. + *

+ * Note: it's critical for connectors to return {@link Optional#empty()} if calling this method has no effect for that + * invocation, even if the connector generally supports pushdown. Doing otherwise can cause the optimizer + * to loop indefinitely. + *

+ * If the method returns a result, the list of assignments in the result will be merged with existing assignments. The projections + * returned by the method must have the same order as the given input list of aggregates. + *

+ * As an example, given the following plan: + * + *
+     *  - aggregation  (GROUP BY c)
+     *          variable0 = agg_fn1(a)
+     *          variable1 = agg_fn2(b, 2)
+     *          variable2 = c
+     *          - scan (TH0)
+     *              a = CH0
+     *              b = CH1
+     *              c = CH2
+     * 
+ *

+ * The optimizer would call this method with the following arguments: + * + *

+     *      handle = TH0
+     *      aggregates = [
+     *              { functionName=agg_fn1, outputType = «some presto type», inputs = [{@link Variable} a]},
+     *              { functionName=agg_fn2, outputType = «some presto type», inputs = [{@link Variable} b, {@link Constant} 2]}
+     *      ]
+     *      groupingSets=[[{@link ColumnHandle} CH2]]
+     *      assignments = {a = CH0, b = CH1, c = CH2}
+     * 
+ *

+ * + * Assuming the connector knows how to handle {@code agg_fn1(...)} and {@code agg_fn2(...)}, it would return: + *
+     *
+     * {@link AggregationApplicationResult} {
+     *      handle = TH1
+     *      projections = [{@link Variable} synthetic_name0, {@link Variable} synthetic_name1] -- The order in the list must be the same as the input list of aggregates
+     *      assignments = {
+     *          synthetic_name0 = CH3 (synthetic column for agg_fn1(a))
+     *          synthetic_name1 = CH4 (synthetic column for agg_fn2(b,2))
+     *      }
+     * }
+     * 
+ * + * If the connector only knows how to handle {@code agg_fn1(...)}, but not {@code agg_fn2(...)}, it should return {@link Optional#empty()}. + * + *

+ * Another example is where the connector wants to handle the aggregate function by pointing to a pre-materialized table. + * In this case the input can stay the same as in the example above, and the connector can return: + *

+     * {@link AggregationApplicationResult} {
+     *      handle = TH1 (could capture information about which pre-materialized table to use)
+     *      projections = [{@link Variable} synthetic_name0, {@link Variable} synthetic_name1] -- The order in the list must be the same as the input list of aggregates
+     *      assignments = {
+     *          synthetic_name0 = CH3 (reference to the column in pre-materialized table that has agg_fn1(a) calculated)
+     *          synthetic_name1 = CH4 (reference to the column in pre-materialized table that has agg_fn2(b,2) calculated)
+     *          synthetic_name2 = CH5 (reference to the column in pre-materialized table that has the group by column c)
+     *      }
+     *      groupingColumnMapping = {
+     *          CH2 -> CH5 (CH2 in the original assignment should now be replaced by CH5 in the new assignment)
+     *      }
+     * }
+     * 
+ *

+ */ + default Optional> applyAggregation( + ConnectorSession session, + ConnectorTableHandle handle, + List aggregates, + Map assignments, + List> groupingSets) + { + return Optional.empty(); + } + /** * Allows the connector to reject the table scan produced by the planner. *

diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/ProjectionApplicationResult.java b/presto-spi/src/main/java/io/prestosql/spi/connector/ProjectionApplicationResult.java index cdd6d9a78fe2b..645a7ff953319 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/connector/ProjectionApplicationResult.java +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/ProjectionApplicationResult.java @@ -14,7 +14,6 @@ package io.prestosql.spi.connector; import io.prestosql.spi.expression.ConnectorExpression; -import io.prestosql.spi.type.Type; import java.util.List; @@ -47,33 +46,4 @@ public List getAssignments() { return assignments; } - - public static class Assignment - { - private final String variable; - private final ColumnHandle column; - private final Type type; - - public Assignment(String variable, ColumnHandle column, Type type) - { - this.variable = requireNonNull(variable, "variable is null"); - this.column = requireNonNull(column, "column is null"); - this.type = requireNonNull(type, "type is null"); - } - - public String getVariable() - { - return variable; - } - - public ColumnHandle getColumn() - { - return column; - } - - public Type getType() - { - return type; - } - } } diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/SortItem.java b/presto-spi/src/main/java/io/prestosql/spi/connector/SortItem.java new file mode 100644 index 0000000000000..e7d4161c9bcc5 --- /dev/null +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/SortItem.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.connector; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class SortItem +{ + private final String name; + private final SortOrder sortOrder; + + public SortItem(String name, SortOrder sortOrder) + { + this.name = requireNonNull(name, "name is null"); + this.sortOrder = requireNonNull(sortOrder, "name is null"); + } + + public String getName() + { + return name; + } + + public SortOrder getSortOrder() + { + return sortOrder; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SortItem sortItem = (SortItem) o; + return name.equals(sortItem.name) && sortOrder == sortItem.sortOrder; + } + + @Override + public int hashCode() + { + return Objects.hash(name, sortOrder); + } +} diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/SortOrder.java b/presto-spi/src/main/java/io/prestosql/spi/connector/SortOrder.java new file mode 100644 index 0000000000000..e83c91f1bf2c0 --- /dev/null +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/SortOrder.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.spi.connector; + +public enum SortOrder +{ + ASC_NULLS_FIRST(true, true), + ASC_NULLS_LAST(true, false), + DESC_NULLS_FIRST(false, true), + DESC_NULLS_LAST(false, false); + + private final boolean ascending; + private final boolean nullsFirst; + + SortOrder(boolean ascending, boolean nullsFirst) + { + this.ascending = ascending; + this.nullsFirst = nullsFirst; + } + + public boolean isAscending() + { + return ascending; + } + + public boolean isNullsFirst() + { + return nullsFirst; + } +} diff --git a/presto-thrift/src/main/java/io/prestosql/plugin/thrift/ThriftMetadata.java b/presto-thrift/src/main/java/io/prestosql/plugin/thrift/ThriftMetadata.java index f061b582066f2..7d673f937b30f 100644 --- a/presto-thrift/src/main/java/io/prestosql/plugin/thrift/ThriftMetadata.java +++ b/presto-thrift/src/main/java/io/prestosql/plugin/thrift/ThriftMetadata.java @@ -28,6 +28,7 @@ import io.prestosql.plugin.thrift.api.PrestoThriftService; import io.prestosql.plugin.thrift.api.PrestoThriftServiceException; import io.prestosql.spi.PrestoException; +import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.connector.ConnectorMetadata; @@ -39,7 +40,6 @@ import io.prestosql.spi.connector.Constraint; import io.prestosql.spi.connector.ConstraintApplicationResult; import io.prestosql.spi.connector.ProjectionApplicationResult; -import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; import io.prestosql.spi.connector.TableNotFoundException;