Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPI and engine changes for aggregation pushdown #3697

Merged
merged 1 commit into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorInsertTableHandle;
Expand All @@ -33,7 +34,6 @@
import io.prestosql.spi.connector.ConstraintApplicationResult;
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.connector.SystemTable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.airlift.log.Logger;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorMetadata;
Expand All @@ -37,7 +38,6 @@
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.NotFoundException;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.connector.TableNotFoundException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
Expand All @@ -68,7 +69,6 @@
import io.prestosql.spi.connector.DiscretePredicates;
import io.prestosql.spi.connector.InMemoryRecordSet;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.SchemaNotFoundException;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import io.prestosql.spi.Page;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorInsertTableHandle;
Expand All @@ -84,7 +85,6 @@
import io.prestosql.spi.connector.ConstraintApplicationResult;
import io.prestosql.spi.connector.DiscretePredicates;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.RecordCursor;
import io.prestosql.spi.connector.RecordPageSource;
import io.prestosql.spi.connector.SchemaTableName;
Expand Down Expand Up @@ -3057,7 +3057,7 @@ public void testApplyProjection()
assertProjectionResult(projectionResult, false, expectedProjections, expectedAssignments);

// Round-2: input projections [symbol_2.int0 and onelevelrow0#f_int0]. Virtual handle is reused.
ProjectionApplicationResult.Assignment newlyCreatedColumn = getOnlyElement(projectionResult.get().getAssignments().stream()
Assignment newlyCreatedColumn = getOnlyElement(projectionResult.get().getAssignments().stream()
.filter(handle -> handle.getVariable().equals("onelevelrow0#f_int0"))
.collect(toList()));
inputAssignments = ImmutableMap.<String, ColumnHandle>builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.airlift.slice.Slice;
import io.prestosql.plugin.kudu.properties.KuduTableProperties;
import io.prestosql.plugin.kudu.properties.PartitionDesign;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorInsertTableHandle;
Expand All @@ -33,7 +34,6 @@
import io.prestosql.spi.connector.ConstraintApplicationResult;
import io.prestosql.spi.connector.NotFoundException;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult.Assignment;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.expression.ConnectorExpression;
Expand Down
9 changes: 9 additions & 0 deletions presto-main/src/main/java/io/prestosql/metadata/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.block.BlockEncoding;
import io.prestosql.spi.block.BlockEncodingSerde;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.AggregationApplicationResult;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
Expand Down Expand Up @@ -353,6 +355,13 @@ public interface Metadata

Optional<TableHandle> applySample(Session session, TableHandle table, SampleType sampleType, double sampleRatio);

Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
Session session,
TableHandle table,
List<AggregateFunction> aggregations,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets);

default void validateScan(Session session, TableHandle table) {}

//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
import io.prestosql.spi.block.SingleMapBlockEncoding;
import io.prestosql.spi.block.SingleRowBlockEncoding;
import io.prestosql.spi.block.VariableWidthBlockEncoding;
import io.prestosql.spi.connector.AggregateFunction;
import io.prestosql.spi.connector.AggregationApplicationResult;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
Expand Down Expand Up @@ -1108,6 +1111,57 @@ public Optional<TableHandle> applySample(Session session, TableHandle table, Sam
Optional.empty()));
}

@Override
public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
Session session,
TableHandle table,
List<AggregateFunction> aggregations,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
{
CatalogName catalogName = table.getCatalogName();
ConnectorMetadata metadata = getMetadata(session, catalogName);

if (metadata.usesLegacyTableLayouts()) {
return Optional.empty();
}

ConnectorSession connectorSession = session.toConnectorSession(catalogName);
return metadata.applyAggregation(connectorSession, table.getConnectorHandle(), aggregations, assignments, groupingSets)
.map(result -> {
verifyProjection(table, result.getProjections(), result.getAssignments(), aggregations.size());

return new AggregationApplicationResult<>(
new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()),
result.getProjections(),
result.getAssignments(),
result.getGroupingColumnMapping());
});
}

private void verifyProjection(TableHandle table, List<ConnectorExpression> projections, List<Assignment> assignments, int expectedProjectionSize)
{
projections.forEach(projection -> requireNonNull(projection, "one of the projections is null"));
assignments.forEach(assignment -> requireNonNull(assignment, "one of the assignments is null"));

verify(
expectedProjectionSize == projections.size(),
"ConnectorMetadata returned invalid number of projections: %s instead of %s for %s",
projections.size(),
expectedProjectionSize,
table);

Set<String> assignedVariables = assignments.stream()
.map(Assignment::getVariable)
.collect(toImmutableSet());
projections.stream()
.flatMap(connectorExpression -> ConnectorExpressions.extractVariables(connectorExpression).stream())
.map(Variable::getName)
.filter(variableName -> !assignedVariables.contains(variableName))
.findAny()
.ifPresent(variableName -> { throw new IllegalStateException("Unbound variable: " + variableName); });
}

@Override
public void validateScan(Session session, TableHandle table)
{
Expand Down Expand Up @@ -1146,24 +1200,7 @@ public Optional<ProjectionApplicationResult<TableHandle>> applyProjection(Sessio
ConnectorSession connectorSession = session.toConnectorSession(catalogName);
return metadata.applyProjection(connectorSession, table.getConnectorHandle(), projections, assignments)
.map(result -> {
result.getProjections().forEach(projection -> requireNonNull(projection, "one of the projections is null"));
result.getAssignments().forEach(assignment -> requireNonNull(assignment, "one of the assignments is null"));

verify(
projections.size() == result.getProjections().size(),
"ConnectorMetadata returned invalid number of projections: %s instead of %s for %s",
result.getProjections().size(),
projections.size(),
table);

Set<String> assignedVariables = result.getAssignments().stream()
.map(ProjectionApplicationResult.Assignment::getVariable)
.collect(toImmutableSet());
result.getProjections().stream()
.flatMap(connectorExpression -> ConnectorExpressions.extractVariables(connectorExpression).stream())
.map(Variable::getName)
.filter(variableName -> !assignedVariables.contains(variableName))
.findAny().ifPresent(variableName -> { throw new IllegalStateException("Unbound variable: " + variableName); });
verifyProjection(table, result.getProjections(), result.getAssignments(), projections.size());

return new ProjectionApplicationResult<>(
new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import io.prestosql.sql.planner.iterative.rule.PruneUnnestSourceColumns;
import io.prestosql.sql.planner.iterative.rule.PruneValuesColumns;
import io.prestosql.sql.planner.iterative.rule.PruneWindowColumns;
import io.prestosql.sql.planner.iterative.rule.PushAggregationIntoTableScan;
import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin;
import io.prestosql.sql.planner.iterative.rule.PushDeleteIntoConnector;
import io.prestosql.sql.planner.iterative.rule.PushDownDereferenceThroughFilter;
Expand Down Expand Up @@ -508,6 +509,7 @@ public PlanOptimizers(
.add(new PushLimitIntoTableScan(metadata))
.add(new PushPredicateIntoTableScan(metadata, typeAnalyzer))
.add(new PushSampleIntoTableScan(metadata))
.add(new PushAggregationIntoTableScan(metadata))
.build()),
new IterativeOptimizer(
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.TableHandle;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
Expand Down Expand Up @@ -88,7 +89,7 @@ public static Optional<PlanNode> pruneColumns(Metadata metadata, TypeProvider ty
handle = result.get().getHandle();

Map<String, ColumnHandle> assignments = result.get().getAssignments().stream()
.collect(toImmutableMap(ProjectionApplicationResult.Assignment::getVariable, ProjectionApplicationResult.Assignment::getColumn));
.collect(toImmutableMap(Assignment::getVariable, Assignment::getColumn));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving out Assignment as a top level class touches a lot of files, that are otherwise unrelated to the change.
Can you please extract this to a preparatory commit?


ImmutableMap.Builder<Symbol, ColumnHandle> builder = ImmutableMap.builder();
for (int i = 0; i < newOutputs.size(); i++) {
Expand Down
Loading