Rewrite partial top-n node to LimitNode or LastNNode
rice668 committed Jan 24, 2024
1 parent 5eddf92 commit b345b77
Showing 42 changed files with 1,983 additions and 161 deletions.
11 changes: 11 additions & 0 deletions core/trino-main/pom.xml
@@ -46,6 +46,11 @@
<artifactId>oshi-core</artifactId>
</dependency>

<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</dependency>

<dependency>
<groupId>com.google.errorprone</groupId>
<artifactId>error_prone_annotations</artifactId>
@@ -521,6 +526,12 @@
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<profiles>
@@ -210,6 +210,7 @@ public final class SystemSessionProperties
public static final String PAGE_PARTITIONING_BUFFER_POOL_SIZE = "page_partitioning_buffer_pool_size";
public static final String IDLE_WRITER_MIN_DATA_SIZE_THRESHOLD = "idle_writer_min_data_size_threshold";
public static final String CLOSE_IDLE_WRITERS_TRIGGER_DURATION = "close_idle_writers_trigger_duration";
public static final String ALLOW_PUSH_PARTIAL_TOP_N_TO_TABLE_SCAN = "allow_push_partial_top_n_to_table_scan";

private final List<PropertyMetadata<?>> sessionProperties;

@@ -1070,7 +1071,11 @@ public SystemSessionProperties(
durationProperty(CLOSE_IDLE_WRITERS_TRIGGER_DURATION,
"The duration after which the writer operator tries to close the idle writers",
new Duration(5, SECONDS),
true));
true),
booleanProperty(ALLOW_PUSH_PARTIAL_TOP_N_TO_TABLE_SCAN,
"Allow push partial sort to table scan.",
false,
false));
}

@Override
@@ -1204,6 +1209,11 @@ public static boolean isOptimizeMetadataQueries(Session session)
return session.getSystemProperty(OPTIMIZE_METADATA_QUERIES, Boolean.class);
}

public static boolean isPushPartialTopNIntoTableScan(Session session)
{
return session.getSystemProperty(ALLOW_PUSH_PARTIAL_TOP_N_TO_TABLE_SCAN, Boolean.class);
}
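
The new flag is read through the accessor above, so a planner rule can check it before attempting the rewrite. A minimal guard sketch follows; the helper class and method names are assumptions for illustration and are not part of this commit.

import io.trino.Session;
import io.trino.SystemSessionProperties;

// Hypothetical guard used before rewriting a partial top-n; only
// isPushPartialTopNIntoTableScan comes from this commit.
final class PartialTopNPushdownGuard
{
    private PartialTopNPushdownGuard() {}

    static boolean shouldAttemptPushdown(Session session)
    {
        // Defaults to false, matching the session property declared above.
        return SystemSessionProperties.isPushPartialTopNIntoTableScan(session);
    }
}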

public static DataSize getQueryMaxMemory(Session session)
{
return session.getSystemProperty(QUERY_MAX_MEMORY, DataSize.class);
13 changes: 13 additions & 0 deletions core/trino-main/src/main/java/io/trino/metadata/Metadata.java
@@ -35,6 +35,7 @@
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
import io.trino.spi.connector.MaterializedViewFreshness;
import io.trino.spi.connector.PartialSortApplicationResult;
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.connector.RelationCommentMetadata;
import io.trino.spi.connector.RelationType;
@@ -44,6 +45,7 @@
import io.trino.spi.connector.SaveMode;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SortItem;
import io.trino.spi.connector.SortOrder;
import io.trino.spi.connector.SystemTable;
import io.trino.spi.connector.TableColumnsMetadata;
import io.trino.spi.connector.TableFunctionApplicationResult;
@@ -553,6 +555,17 @@ Optional<TopNApplicationResult<TableHandle>> applyTopN(
List<SortItem> sortItems,
Map<String, ColumnHandle> assignments);

/**
* Attempt to push down partial sort or top n to table scan.
*/
default Optional<PartialSortApplicationResult<TableHandle>> applyPartialSort(
Session session,
TableHandle tableHandle,
Map<ColumnHandle, SortOrder> columnHandleSortOrderMap)
{
return Optional.empty();
}

Optional<TableFunctionApplicationResult<TableHandle>> applyTableFunction(Session session, TableFunctionHandle handle);

default void validateScan(Session session, TableHandle table) {}
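
The default implementation above keeps existing connectors working unchanged: they return Optional.empty() and the partial sort stays in the engine. A hedged caller-side sketch of how the engine could offer the sort to a connector follows; everything except applyPartialSort and the SPI types it references is assumed for illustration.

import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.PartialSortApplicationResult;
import io.trino.spi.connector.SortOrder;

import java.util.Map;
import java.util.Optional;

// Hypothetical engine-side helper; the surrounding optimizer rule, symbol
// mapping, and handling of the returned result are not shown in this commit.
final class PartialSortPushdownSketch
{
    private PartialSortPushdownSketch() {}

    static Optional<PartialSortApplicationResult<TableHandle>> tryPushDown(
            Metadata metadata,
            Session session,
            TableHandle table,
            Map<ColumnHandle, SortOrder> sortOrders)
    {
        // Connectors that do not override applyPartialSort hit the default
        // above and return Optional.empty(), leaving the plan unchanged.
        return metadata.applyPartialSort(session, table, sortOrders);
    }
}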
@@ -52,7 +52,8 @@ private FilterAndProjectOperator(
List<Type> types,
DataSize minOutputPageSize,
int minOutputPageRowCount,
boolean avoidPageMaterialization)
boolean avoidPageMaterialization,
boolean iteratorEnd)
{
AggregatedMemoryContext localAggregatedMemoryContext = newSimpleAggregatedMemoryContext();
LocalMemoryContext outputMemoryContext = localAggregatedMemoryContext.newLocalMemoryContext(FilterAndProjectOperator.class.getSimpleName());
@@ -64,8 +65,9 @@
yieldSignal,
outputMemoryContext,
metrics,
page))
.transformProcessor(processor -> mergePages(types, minOutputPageSize.toBytes(), minOutputPageRowCount, processor, localAggregatedMemoryContext))
page,
iteratorEnd))
.transformProcessor(processor -> mergePages(types, minOutputPageSize.toBytes(), minOutputPageRowCount, processor, localAggregatedMemoryContext, iteratorEnd))
.blocking(() -> memoryTrackingContext.localUserMemoryContext().setBytes(localAggregatedMemoryContext.getBytes()));
}

@@ -87,15 +89,17 @@ public static OperatorFactory createOperatorFactory(
Supplier<PageProcessor> processor,
List<Type> types,
DataSize minOutputPageSize,
int minOutputPageRowCount)
int minOutputPageRowCount,
boolean iteratorEnd)
{
return createAdapterOperatorFactory(new Factory(
operatorId,
planNodeId,
processor,
types,
minOutputPageSize,
minOutputPageRowCount));
minOutputPageRowCount,
iteratorEnd));
}

private static class Factory
@@ -108,21 +112,24 @@ private static class Factory
private final DataSize minOutputPageSize;
private final int minOutputPageRowCount;
private boolean closed;
private final boolean iteratorEnd;

private Factory(
int operatorId,
PlanNodeId planNodeId,
Supplier<PageProcessor> processor,
List<Type> types,
DataSize minOutputPageSize,
int minOutputPageRowCount)
int minOutputPageRowCount,
boolean iteratorEnd)
{
this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
this.processor = requireNonNull(processor, "processor is null");
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.minOutputPageSize = requireNonNull(minOutputPageSize, "minOutputPageSize is null");
this.minOutputPageRowCount = minOutputPageRowCount;
this.iteratorEnd = iteratorEnd;
}

@Override
@@ -138,7 +145,8 @@ public WorkProcessorOperator create(ProcessorContext processorContext, WorkProce
types,
minOutputPageSize,
minOutputPageRowCount,
true);
true,
iteratorEnd);
}

@Override
@@ -154,7 +162,8 @@ public WorkProcessorOperator createAdapterOperator(ProcessorContext processorCon
types,
minOutputPageSize,
minOutputPageRowCount,
false);
false,
iteratorEnd);
}

@Override
@@ -184,7 +193,7 @@ public void close()
@Override
public BasicAdapterWorkProcessorOperatorFactory duplicate()
{
return new Factory(operatorId, planNodeId, processor, types, minOutputPageSize, minOutputPageRowCount);
return new Factory(operatorId, planNodeId, processor, types, minOutputPageSize, minOutputPageRowCount, iteratorEnd);
}
}
}
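
The new iteratorEnd flag is threaded through both create paths and, notably, through duplicate(), so copies of the factory keep the same behavior. A sketch of calling the widened factory method follows; all arguments except the trailing iteratorEnd flag are placeholders, not values from this commit.

import io.airlift.units.DataSize;
import io.trino.operator.FilterAndProjectOperator;
import io.trino.operator.OperatorFactory;
import io.trino.operator.project.PageProcessor;
import io.trino.spi.type.Type;
import io.trino.sql.planner.plan.PlanNodeId;

import java.util.List;
import java.util.function.Supplier;

// Hypothetical call site for the widened factory method.
final class FilterAndProjectFactorySketch
{
    private FilterAndProjectFactorySketch() {}

    static OperatorFactory create(
            int operatorId,
            PlanNodeId planNodeId,
            Supplier<PageProcessor> pageProcessorSupplier,
            List<Type> outputTypes)
    {
        return FilterAndProjectOperator.createOperatorFactory(
                operatorId,
                planNodeId,
                pageProcessorSupplier,
                outputTypes,
                DataSize.of(64, DataSize.Unit.KILOBYTE),
                1024,
                /* iteratorEnd */ false);
    }
}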
196 changes: 196 additions & 0 deletions core/trino-main/src/main/java/io/trino/operator/LastNOperator.java
@@ -0,0 +1,196 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator;

import io.trino.spi.Page;
import io.trino.sql.planner.plan.PlanNodeId;

import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Iterator;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public class LastNOperator
implements Operator
{
public static class LastNOperatorFactory
implements OperatorFactory
{
private final int operatorId;
private final PlanNodeId planNodeId;
private final int count;
private boolean closed;

public LastNOperatorFactory(int operatorId, PlanNodeId planNodeId, int count)
{
this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
this.count = count;
}

@Override
public Operator createOperator(DriverContext driverContext)
{
checkState(!closed, "Factory is already closed");
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, LastNOperator.class.getSimpleName());
return new LastNOperator(operatorContext, count);
}

@Override
public void noMoreOperators()
{
closed = true;
}

@Override
public OperatorFactory duplicate()
{
return new LastNOperatorFactory(operatorId, planNodeId, count);
}
}

private enum State
{
NEEDS_INPUT,
FINISHED
}

private final OperatorContext operatorContext;
private long count;
private long totalPositions;
private State state = State.NEEDS_INPUT;
private long lastPageRowGroupId;
private final ArrayDeque<Page> outputPages;
private boolean canOutput;
private int remainingLimit;
private ArrayDeque<Page> currentQueue;
private ArrayDeque<ArrayDeque<Page>> totalQueue;

public LastNOperator(OperatorContext operatorContext, int count)
{
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
checkArgument(count >= 0, "count must be at least zero");
this.count = count;
this.remainingLimit = count;
this.currentQueue = new ArrayDeque<>();
this.totalQueue = new ArrayDeque<>();
this.outputPages = new ArrayDeque<>();
}

@Override
public OperatorContext getOperatorContext()
{
return operatorContext;
}

@Override
public void finish()
{
state = State.FINISHED;
}

@Override
public boolean isFinished()
{
return state == State.FINISHED && totalQueue == null;
}

@Override
public boolean needsInput()
{
return count > 0 && state != State.FINISHED;
}

@Override
public void addInput(Page page)
{
checkState(state == State.NEEDS_INPUT, "Operator is already finishing");
long currPageRowGroupId = getCurrentPageRowGroupId(page);
if (currPageRowGroupId != lastPageRowGroupId) {
totalPositions = 0;
if (!currentQueue.isEmpty()) {
totalQueue.add(reverseDeque(currentQueue));
int countInLastRowGroup = 0;
for (Page currPage : currentQueue) {
countInLastRowGroup += currPage.getPositionCount();
}
count = count - countInLastRowGroup;
currentQueue = new ArrayDeque<>();
}
}
lastPageRowGroupId = currPageRowGroupId;
totalPositions += page.getPositionCount();
currentQueue.add(page);
while (!currentQueue.isEmpty() && totalPositions - currentQueue.peek().getPositionCount() >= count) {
totalPositions -= currentQueue.remove().getPositionCount();
}
}

private ArrayDeque<Page> reverseDeque(ArrayDeque<Page> deque)
{
ArrayDeque<Page> reversedDeque = new ArrayDeque<>();
Iterator<Page> descendingIterator = deque.descendingIterator();
while (descendingIterator.hasNext()) {
reversedDeque.add(descendingIterator.next());
}
return reversedDeque;
}

private long getCurrentPageRowGroupId(Page page)
{
return page.getBlock(page.getChannelCount() - 1).getLong(0, 0);
}

@Override
public Page getOutput()
{
if (count > 0 && state != State.FINISHED) {
return null;
}
if (!canOutput) {
if (!currentQueue.isEmpty()) {
totalQueue.add(reverseDeque(currentQueue));
}
ArrayDeque<Page> pages = totalQueue.stream().flatMap(Collection::stream).collect(Collectors.toCollection(ArrayDeque::new));
for (Page page : pages) {
if (page.getPositionCount() <= remainingLimit) {
remainingLimit = remainingLimit - page.getPositionCount();
outputPages.add(page);
}
else {
Page region = page.getRegion(page.getPositionCount() - remainingLimit, remainingLimit);
outputPages.add(region);
}
}
canOutput = true;
}
if (outputPages.isEmpty()) {
state = State.FINISHED;
totalQueue = null;
return null;
}
return outputPages.removeFirst();
}

@Override
public void close()
throws Exception
{
totalQueue = null;
}
}
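
LastNOperator buffers incoming pages and drops pages from the front of the current buffer once enough trailing rows are retained, so after finish() it is intended to emit at most the trailing count rows; the last channel of each page is read as a row-group id (getCurrentPageRowGroupId) so the buffering can restart across row groups. A small wiring sketch; the driver context, operator id, and plan node id are placeholders, not taken from this commit.

import io.trino.operator.DriverContext;
import io.trino.operator.LastNOperator;
import io.trino.operator.Operator;
import io.trino.operator.OperatorFactory;
import io.trino.sql.planner.plan.PlanNodeId;

// Hypothetical wiring of the new operator into a driver pipeline.
final class LastNWiringSketch
{
    private LastNWiringSketch() {}

    static Operator createLastN(DriverContext driverContext)
    {
        OperatorFactory factory = new LastNOperator.LastNOperatorFactory(
                /* operatorId */ 0,
                new PlanNodeId("lastN"),
                /* count */ 100);
        // Pages added via addInput() are expected to carry a row-group id in
        // their last channel; output becomes available after finish().
        return factory.createOperator(driverContext);
    }
}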