Skip to content

Commit

Permalink
Add support for TopN pushdown
Browse files Browse the repository at this point in the history
  • Loading branch information
Parth-Brahmbhatt authored and martint committed Aug 5, 2020
1 parent f4294ab commit 3602068
Show file tree
Hide file tree
Showing 12 changed files with 434 additions and 10 deletions.
9 changes: 9 additions & 0 deletions presto-main/src/main/java/io/prestosql/metadata/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.SampleType;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.connector.TopNApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.OperatorType;
Expand Down Expand Up @@ -367,6 +369,13 @@ Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets);

Optional<TopNApplicationResult<TableHandle>> applyTopN(
Session session,
TableHandle handle,
long topNCount,
List<SortItem> sortItems,
Map<String, ColumnHandle> assignments);

default void validateScan(Session session, TableHandle table) {}

//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@
import io.prestosql.spi.connector.SampleType;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.connector.TopNApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.expression.Variable;
import io.prestosql.spi.function.InvocationConvention;
Expand Down Expand Up @@ -1160,6 +1162,28 @@ public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
});
}

@Override
public Optional<TopNApplicationResult<TableHandle>> applyTopN(
Session session,
TableHandle table,
long topNCount,
List<SortItem> sortItems,
Map<String, ColumnHandle> assignments)
{
CatalogName catalogName = table.getCatalogName();
ConnectorMetadata metadata = getMetadata(session, catalogName);

if (metadata.usesLegacyTableLayouts()) {
return Optional.empty();
}

ConnectorSession connectorSession = session.toConnectorSession(catalogName);
return metadata.applyTopN(connectorSession, table.getConnectorHandle(), topNCount, sortItems, assignments)
.map(result -> new TopNApplicationResult<>(
new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()),
result.isTopNGuaranteed()));
}

private void verifyProjection(TableHandle table, List<ConnectorExpression> projections, List<Assignment> assignments, int expectedProjectionSize)
{
projections.forEach(projection -> requireNonNull(projection, "one of the projections is null"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,13 @@ public static SortOrder sortItemToSortOrder(SortItem sortItem)
}
return SortOrder.DESC_NULLS_LAST;
}

public List<io.prestosql.spi.connector.SortItem> toSortItems()
{
return getOrderBy().stream()
.map(symbol -> new io.prestosql.spi.connector.SortItem(
symbol.getName(),
io.prestosql.spi.connector.SortOrder.valueOf(getOrdering(symbol).name())))
.collect(toImmutableList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
import io.prestosql.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId;
import io.prestosql.sql.planner.iterative.rule.PushSampleIntoTableScan;
import io.prestosql.sql.planner.iterative.rule.PushTableWriteThroughUnion;
import io.prestosql.sql.planner.iterative.rule.PushTopNIntoTableScan;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughOuterJoin;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughProject;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughUnion;
Expand Down Expand Up @@ -623,7 +624,8 @@ public PlanOptimizers(
new CreatePartialTopN(),
new PushTopNThroughProject(),
new PushTopNThroughOuterJoin(),
new PushTopNThroughUnion())));
new PushTopNThroughUnion(),
new PushTopNIntoTableScan(metadata))));
builder.add(new IterativeOptimizer(
ruleStats,
statsCalculator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SortOrder;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.expression.Variable;
import io.prestosql.sql.planner.ConnectorExpressionTranslator;
Expand Down Expand Up @@ -200,12 +199,7 @@ private static AggregateFunction toAggregateFunction(Context context, Aggregatio
}

Optional<OrderingScheme> orderingScheme = aggregation.getOrderingScheme();
Optional<List<SortItem>> sortBy = orderingScheme.map(orderings ->
orderings.getOrderBy().stream()
.map(orderBy -> new SortItem(
orderBy.getName(),
SortOrder.valueOf(orderings.getOrderings().get(orderBy).name())))
.collect(toImmutableList()));
Optional<List<SortItem>> sortBy = orderingScheme.map(OrderingScheme::toSortItems);

Optional<ConnectorExpression> filter = aggregation.getFilter()
.map(symbol -> new Variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol)));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TopNNode;

import java.util.List;
import java.util.Map;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.prestosql.matching.Capture.newCapture;
import static io.prestosql.sql.planner.plan.Patterns.source;
import static io.prestosql.sql.planner.plan.Patterns.tableScan;
import static io.prestosql.sql.planner.plan.Patterns.topN;

public class PushTopNIntoTableScan
implements Rule<TopNNode>
{
private static final Capture<TableScanNode> TABLE_SCAN = newCapture();

private static final Pattern<TopNNode> PATTERN = topN().with(source().matching(
tableScan().capturedAs(TABLE_SCAN)));

private final Metadata metadata;

public PushTopNIntoTableScan(Metadata metadata)
{
this.metadata = metadata;
}

@Override
public Pattern<TopNNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(TopNNode topNNode, Captures captures, Context context)
{
TableScanNode tableScan = captures.get(TABLE_SCAN);

long topNCount = topNNode.getCount();
List<SortItem> sortItems = topNNode.getOrderingScheme().toSortItems();

Map<String, ColumnHandle> assignments = tableScan.getAssignments()
.entrySet().stream()
.collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));

return metadata.applyTopN(context.getSession(), tableScan.getTable(), topNCount, sortItems, assignments)
.map(result -> {
PlanNode node = TableScanNode.newInstance(
context.getIdAllocator().getNextId(),
result.getHandle(),
tableScan.getOutputSymbols(),
tableScan.getAssignments());

if (!result.isTopNGuaranteed()) {
node = new TopNNode(topNNode.getId(), node, topNNode.getCount(), topNNode.getOrderingScheme(), TopNNode.Step.FINAL);
}
return Result.ofPlanNode(node);
})
.orElseGet(Result::empty);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.TopNApplicationResult;
import io.prestosql.spi.eventlistener.EventListener;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.security.PrestoPrincipal;
Expand Down Expand Up @@ -75,6 +77,7 @@ public class MockConnectorFactory
private final BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle;
private final Function<SchemaTableName, List<ColumnMetadata>> getColumns;
private final ApplyProjection applyProjection;
private final ApplyTopN applyTopN;
private final BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout;
private final BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout;
private final Supplier<Iterable<EventListener>> eventListeners;
Expand All @@ -87,6 +90,7 @@ private MockConnectorFactory(
BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle,
Function<SchemaTableName, List<ColumnMetadata>> getColumns,
ApplyProjection applyProjection,
ApplyTopN applyTopN,
BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout,
BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout,
Supplier<Iterable<EventListener>> eventListeners,
Expand All @@ -98,6 +102,7 @@ private MockConnectorFactory(
this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null");
this.getColumns = getColumns;
this.applyProjection = applyProjection;
this.applyTopN = requireNonNull(applyTopN, "applyTopN is null");
this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null");
this.getNewTableLayout = requireNonNull(getNewTableLayout, "getNewTableLayout is null");
this.eventListeners = requireNonNull(eventListeners, "eventListeners is null");
Expand All @@ -119,7 +124,7 @@ public ConnectorHandleResolver getHandleResolver()
@Override
public Connector create(String catalogName, Map<String, String> config, ConnectorContext context)
{
return new MockConnector(context, listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, getInsertLayout, getNewTableLayout, eventListeners, roleGrants);
return new MockConnector(context, listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, applyTopN, getInsertLayout, getNewTableLayout, eventListeners, roleGrants);
}

public static Builder builder()
Expand All @@ -133,6 +138,12 @@ public interface ApplyProjection
Optional<ProjectionApplicationResult<ConnectorTableHandle>> apply(ConnectorSession session, ConnectorTableHandle handle, List<ConnectorExpression> projections, Map<String, ColumnHandle> assignments);
}

@FunctionalInterface
public interface ApplyTopN
{
Optional<TopNApplicationResult<ConnectorTableHandle>> apply(ConnectorSession session, ConnectorTableHandle handle, long topNCount, List<SortItem> sortItems, Map<String, ColumnHandle> assignments);
}

@FunctionalInterface
public interface ListRoleGrants
{
Expand All @@ -149,6 +160,7 @@ public static class MockConnector
private final BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle;
private final Function<SchemaTableName, List<ColumnMetadata>> getColumns;
private final ApplyProjection applyProjection;
private final ApplyTopN applyTopN;
private final BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout;
private final BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout;
private final Supplier<Iterable<EventListener>> eventListeners;
Expand All @@ -162,6 +174,7 @@ private MockConnector(
BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle,
Function<SchemaTableName, List<ColumnMetadata>> getColumns,
ApplyProjection applyProjection,
ApplyTopN applyTopN,
BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout,
BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout,
Supplier<Iterable<EventListener>> eventListeners,
Expand All @@ -174,6 +187,7 @@ private MockConnector(
this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null");
this.getColumns = requireNonNull(getColumns, "getColumns is null");
this.applyProjection = requireNonNull(applyProjection, "applyProjection is null");
this.applyTopN = requireNonNull(applyTopN, "applyTopN is null");
this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null");
this.getNewTableLayout = requireNonNull(getNewTableLayout, "getNewTableLayout is null");
this.eventListeners = requireNonNull(eventListeners, "eventListeners is null");
Expand Down Expand Up @@ -219,6 +233,12 @@ public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjecti
return applyProjection.apply(session, handle, projections, assignments);
}

@Override
public Optional<TopNApplicationResult<ConnectorTableHandle>> applyTopN(ConnectorSession session, ConnectorTableHandle handle, long topNCount, List<SortItem> sortItems, Map<String, ColumnHandle> assignments)
{
return applyTopN.apply(session, handle, topNCount, sortItems, assignments);
}

@Override
public List<String> listSchemaNames(ConnectorSession session)
{
Expand Down Expand Up @@ -414,6 +434,7 @@ public static final class Builder
private BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout = defaultGetNewTableLayout();
private Supplier<Iterable<EventListener>> eventListeners = ImmutableList::of;
private ListRoleGrants roleGrants = defaultRoleAuthorizations();
private ApplyTopN applyTopN = (session, handle, topNCount, sortItems, assignments) -> Optional.empty();

public Builder withListSchemaNames(Function<ConnectorSession, List<String>> listSchemaNames)
{
Expand Down Expand Up @@ -457,6 +478,12 @@ public Builder withApplyProjection(ApplyProjection applyProjection)
return this;
}

public Builder withApplyTopN(ApplyTopN applyTopN)
{
this.applyTopN = applyTopN;
return this;
}

public Builder withGetInsertLayout(BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout)
{
this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null");
Expand Down Expand Up @@ -487,7 +514,7 @@ public Builder withEventListener(Supplier<EventListener> listenerFactory)

public MockConnectorFactory build()
{
return new MockConnectorFactory(listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, getInsertLayout, getNewTableLayout, eventListeners, roleGrants);
return new MockConnectorFactory(listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, applyTopN, getInsertLayout, getNewTableLayout, eventListeners, roleGrants);
}

public static Function<ConnectorSession, List<String>> defaultListSchemaNames()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.SampleType;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.connector.TopNApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.OperatorType;
Expand Down Expand Up @@ -741,4 +743,10 @@ public Optional<ProjectionApplicationResult<TableHandle>> applyProjection(Sessio
{
return Optional.empty();
}

@Override
public Optional<TopNApplicationResult<TableHandle>> applyTopN(Session session, TableHandle handle, long topNFunctions, List<SortItem> sortItems, Map<String, ColumnHandle> assignments)
{
return Optional.empty();
}
}
Loading

0 comments on commit 3602068

Please sign in to comment.