Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Pass columns required by upstream query for group-by pushdown #9698

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
TableHandle table,
List<AggregateFunction> aggregations,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets);
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns);

Optional<JoinApplicationResult<TableHandle>> applyJoin(
Session session,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,8 @@ public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
TableHandle table,
List<AggregateFunction> aggregations,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
// Global aggregation is represented by [[]]
checkArgument(!groupingSets.isEmpty(), "No grouping sets provided");
Expand All @@ -1575,7 +1576,7 @@ public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
}

ConnectorSession connectorSession = session.toConnectorSession(catalogName);
return metadata.applyAggregation(connectorSession, table.getConnectorHandle(), aggregations, assignments, groupingSets)
return metadata.applyAggregation(connectorSession, table.getConnectorHandle(), aggregations, assignments, groupingSets, requiredColumns)
.map(result -> {
verifyProjection(table, result.getProjections(), result.getAssignments(), aggregations.size());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -146,12 +148,16 @@ public static Optional<PlanNode> pushAggregationIntoTableScan(
.map(groupByColumn -> assignments.get(groupByColumn.getName()))
.collect(toImmutableList());

// TODO FIXME How to get columns required by upstream query?
Set<String> requiredColumns = Collections.emptySet();
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still unsure how to properly propagate columns required by the upstream query.


Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = metadata.applyAggregation(
context.getSession(),
tableScan.getTable(),
aggregateFunctions,
assignments,
ImmutableList.of(groupByColumns));
ImmutableList.of(groupByColumns),
requiredColumns);

if (aggregationPushdownResult.isEmpty()) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,10 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ConnectorTableHandle handle,
List<AggregateFunction> aggregates,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
return applyAggregation.apply(session, handle, aggregates, assignments, groupingSets);
return applyAggregation.apply(session, handle, aggregates, assignments, groupingSets, requiredColumns);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ Optional<AggregationApplicationResult<ConnectorTableHandle>> apply(
ConnectorTableHandle handle,
List<AggregateFunction> aggregates,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets);
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns);
}

@FunctionalInterface
Expand Down Expand Up @@ -270,7 +271,7 @@ public static final class Builder
private BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle = defaultGetTableHandle();
private Function<SchemaTableName, List<ColumnMetadata>> getColumns = defaultGetColumns();
private ApplyProjection applyProjection = (session, handle, projections, assignments) -> Optional.empty();
private ApplyAggregation applyAggregation = (session, handle, aggregates, assignments, groupingSets) -> Optional.empty();
private ApplyAggregation applyAggregation = (session, handle, aggregates, assignments, groupingSets, requiredColumns) -> Optional.empty();
private ApplyJoin applyJoin = (session, joinType, left, right, joinConditions, leftAssignments, rightAssignments) -> Optional.empty();
private BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout = defaultGetInsertLayout();
private BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout = defaultGetNewTableLayout();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,8 @@ public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
TableHandle table,
List<AggregateFunction> aggregations,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ protected Optional<LocalQueryRunner> createLocalQueryRunner()
TEST_CATALOG.getCatalogName(),
MockConnectorFactory.builder()
.withApplyAggregation(
(session, handle, aggregates, assignments, groupingSets) -> {
(session, handle, aggregates, assignments, groupingSets, requiredColumns) -> {
if (testApplyAggregation != null) {
return testApplyAggregation.apply(session, handle, aggregates, assignments, groupingSets);
return testApplyAggregation.apply(session, handle, aggregates, assignments, groupingSets, requiredColumns);
}
return Optional.empty();
})
Expand Down Expand Up @@ -115,7 +115,7 @@ public void testDoesNotFireIfNoTableScan()
public void testNoEffect()
{
AtomicInteger applyCallCounter = new AtomicInteger();
testApplyAggregation = (session, handle, aggregates, assignments, groupingSets) -> {
testApplyAggregation = (session, handle, aggregates, assignments, groupingSets, requiredColumns) -> {
applyCallCounter.incrementAndGet();
return Optional.empty();
};
Expand All @@ -142,7 +142,7 @@ public void testPushDistinct()
AtomicReference<Map<String, ColumnHandle>> applyAssignments = new AtomicReference<>();
AtomicReference<List<List<ColumnHandle>>> applyGroupingSets = new AtomicReference<>();

testApplyAggregation = (session, handle, aggregates, assignments, groupingSets) -> {
testApplyAggregation = (session, handle, aggregates, assignments, groupingSets, requiredColumns) -> {
applyCallCounter.incrementAndGet();
applyAggregates.set(List.copyOf(aggregates));
applyAssignments.set(Map.copyOf(assignments));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,8 @@ default Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggreg
ConnectorTableHandle handle,
List<AggregateFunction> aggregates,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -836,10 +836,11 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ConnectorTableHandle table,
List<AggregateFunction> aggregates,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.applyAggregation(session, table, aggregates, assignments, groupingSets);
return delegate.applyAggregation(session, table, aggregates, assignments, groupingSets, requiredColumns);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ConnectorTableHandle table,
List<AggregateFunction> aggregates,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
if (!isAggregationPushdownEnabled(session)) {
return Optional.empty();
Expand Down Expand Up @@ -287,6 +288,7 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
groupKey,
tableColumns))
.orElse(groupKey -> {}))
.filter(column -> requiredColumns.isEmpty() || requiredColumns.contains(column.getColumnName()))
.forEach(newColumns::add);

for (AggregateFunction aggregate : aggregates) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -249,7 +250,8 @@ public void testAggregationPushdownForTableHandle()
handle,
ImmutableList.of(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty())),
ImmutableMap.of(),
ImmutableList.of(ImmutableList.of(groupByColumn)));
ImmutableList.of(ImmutableList.of(groupByColumn)),
Collections.emptySet());

ConnectorTableHandle baseTableHandle = metadata.getTableHandle(session, new SchemaTableName("example", "numbers"));
Optional<AggregationApplicationResult<ConnectorTableHandle>> aggregationResult = applyAggregation.apply(baseTableHandle);
Expand Down Expand Up @@ -371,7 +373,8 @@ private JdbcTableHandle applyCountAggregation(ConnectorSession session, Connecto
tableHandle,
ImmutableList.of(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty())),
ImmutableMap.of(),
groupByColumns);
groupByColumns,
Collections.emptySet());
assertThat(aggResult).isPresent();
return (JdbcTableHandle) aggResult.get().getHandle();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -238,7 +239,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ConnectorTableHandle table,
List<AggregateFunction> aggregates,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
// TODO support aggregation pushdown
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -259,7 +260,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ConnectorTableHandle table,
List<AggregateFunction> aggregates,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
// TODO support aggregation pushdown
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -334,7 +335,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
ConnectorTableHandle handle,
List<AggregateFunction> aggregates,
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets)
List<List<ColumnHandle>> groupingSets,
Set<String> requiredColumns)
{
if (!isAggregationPushdownEnabled(session)) {
return Optional.empty();
Expand Down