Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Sort field push down #848

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,6 @@ public LogicalPlan visitSort(Sort node, AnalysisContext context) {
ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);

// the first options is {"count": "integer"}
Integer count = (Integer) node.getOptions().get(0).getValue().getValue();
List<Pair<SortOption, Expression>> sortList =
node.getSortList().stream()
.map(
Expand All @@ -326,8 +324,7 @@ public LogicalPlan visitSort(Sort node, AnalysisContext context) {
return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression);
})
.collect(Collectors.toList());

return new LogicalSort(child, count, sortList);
return new LogicalSort(child, sortList);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public LogicalPlan visitWindowFunction(WindowFunction node, AnalysisContext cont
WindowDefinition windowDefinition = new WindowDefinition(partitionByList, sortList);

return new LogicalWindow(
new LogicalSort(child, 0, windowDefinition.getAllSortItems()),
new LogicalSort(child,windowDefinition.getAllSortItems()),
windowFunction,
windowDefinition);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,20 +332,16 @@ public static List<Argument> defaultDedupArgs() {
argument("consecutive", booleanLiteral(false)));
}

public static List<Argument> defaultSortOptions() {
return exprList(argument("count", intLiteral(1000)), argument("desc", booleanLiteral(false)));
}

public static List<Argument> sortOptions(int count) {
return exprList(argument("count", intLiteral(count)), argument("desc", booleanLiteral(false)));
public static List<Argument> sortOptions() {
return exprList(argument("desc", booleanLiteral(false)));
}

public static List<Argument> defaultSortFieldArgs() {
return exprList(argument("asc", booleanLiteral(true)), argument("type", nullLiteral()));
}

public static Sort sort(UnresolvedPlan input, List<Argument> options, Field... sorts) {
return new Sort(input, options, Arrays.asList(sorts));
public static Sort sort(UnresolvedPlan input, Field... sorts) {
return new Sort(input, Arrays.asList(sorts));
}

public static Dedupe dedupe(UnresolvedPlan input, List<Argument> options, Field... fields) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOrder.DESC;

import com.amazon.opendistroforelasticsearch.sql.ast.AbstractNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Argument;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Field;
import com.google.common.collect.ImmutableList;
import java.util.List;
Expand All @@ -42,7 +41,6 @@
@AllArgsConstructor
public class Sort extends UnresolvedPlan {
private UnresolvedPlan child;
private final List<Argument> options;
private final List<Field> sortList;

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ public ExplainResponseNode visitFilter(FilterOperator node, Object context) {
@Override
public ExplainResponseNode visitSort(SortOperator node, Object context) {
return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of(
"count", node.getCount(),
"sortList", describeSortList(node.getSortList()))));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public PhysicalPlan visitEval(LogicalEval node, C context) {

@Override
public PhysicalPlan visitSort(LogicalSort node, C context) {
return new SortOperator(visitChild(node, context), node.getCount(), node.getSortList());
return new SortOperator(visitChild(node, context), node.getSortList());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ public static LogicalPlan eval(
return new LogicalEval(input, Arrays.asList(expressions));
}

public static LogicalPlan sort(
LogicalPlan input, Integer count, Pair<SortOption, Expression>... sorts) {
return new LogicalSort(input, count, Arrays.asList(sorts));
public static LogicalPlan sort(LogicalPlan input, Pair<SortOption, Expression>... sorts) {
return new LogicalSort(input, Arrays.asList(sorts));
}

public static LogicalPlan dedupe(LogicalPlan input, Expression... fields) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.apache.commons.lang3.tuple.Pair;

Expand All @@ -34,17 +32,15 @@
@EqualsAndHashCode(callSuper = true)
public class LogicalSort extends LogicalPlan {

private final Integer count;
private final List<Pair<SortOption, Expression>> sortList;

/**
* Constructor of LogicalSort.
*/
public LogicalSort(
LogicalPlan child, Integer count,
LogicalPlan child,
List<Pair<SortOption, Expression>> sortList) {
super(Collections.singletonList(child));
this.count = count;
this.sortList = sortList;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.amazon.opendistroforelasticsearch.sql.expression.DSL;
import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalPlan;
import com.amazon.opendistroforelasticsearch.sql.planner.optimizer.rule.MergeFilterAndFilter;
import com.amazon.opendistroforelasticsearch.sql.planner.optimizer.rule.PushFilterUnderSort;
import com.facebook.presto.matching.Match;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -50,7 +51,8 @@ public LogicalPlanOptimizer(List<Rule<?>> rules) {
*/
public static LogicalPlanOptimizer create(DSL dsl) {
return new LogicalPlanOptimizer(Arrays.asList(
new MergeFilterAndFilter(dsl)));
new MergeFilterAndFilter(dsl),
new PushFilterUnderSort()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.opendistroforelasticsearch.sql.planner.optimizer.rule;

import static com.amazon.opendistroforelasticsearch.sql.planner.optimizer.pattern.Patterns.source;
import static com.facebook.presto.matching.Pattern.typeOf;

import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalFilter;
import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalPlan;
import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalSort;
import com.amazon.opendistroforelasticsearch.sql.planner.optimizer.Rule;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import lombok.Getter;
import lombok.experimental.Accessors;

/**
* Push Filter under Sort.
* Filter - Sort - Child --> Sort - Filter - Child
*/
public class PushFilterUnderSort implements Rule<LogicalFilter> {

private final Capture<LogicalSort> capture;

@Accessors(fluent = true)
@Getter
private final Pattern<LogicalFilter> pattern;

/**
* Constructor of PushFilterUnderSort.
*/
public PushFilterUnderSort() {
this.capture = Capture.newCapture();
this.pattern = typeOf(LogicalFilter.class)
.with(source().matching(typeOf(LogicalSort.class).capturedAs(capture)));
}

@Override
public LogicalPlan apply(LogicalFilter filter,
Captures captures) {
LogicalSort sort = captures.get(capture);
return new LogicalSort(
filter.replaceChildPlans(sort.getChild()),
sort.getSortList()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ public static EvalOperator eval(
return new EvalOperator(input, Arrays.asList(expressions));
}

public static SortOperator sort(PhysicalPlan input, Integer count, Pair<SortOption,
public static SortOperator sort(PhysicalPlan input, Pair<SortOption,
Expression>... sorts) {
return new SortOperator(input, count, Arrays.asList(sorts));
return new SortOperator(input, Arrays.asList(sorts));
}

public static DedupeOperator dedupe(PhysicalPlan input, Expression... expressions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.amazon.opendistroforelasticsearch.sql.data.utils.ExprValueOrdering;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.planner.physical.SortOperator.Sorter.SorterBuilder;
import com.google.common.collect.Iterators;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
Expand All @@ -46,11 +45,7 @@
public class SortOperator extends PhysicalPlan {
@Getter
private final PhysicalPlan input;
/**
* How many sorted result should been return. If count = 0, all the resulted will be returned.
*/
@Getter
private final Integer count;

@Getter
private final List<Pair<SortOption, Expression>> sortList;
@EqualsAndHashCode.Exclude
Expand All @@ -61,14 +56,12 @@ public class SortOperator extends PhysicalPlan {
/**
* Sort Operator Constructor.
* @param input input {@link PhysicalPlan}
* @param count how many sorted result should been return
* @param sortList list of sort sort field.
* The sort field is specified by the {@link Expression} with {@link SortOption}
*/
public SortOperator(
PhysicalPlan input, Integer count, List<Pair<SortOption, Expression>> sortList) {
PhysicalPlan input, List<Pair<SortOption, Expression>> sortList) {
this.input = input;
this.count = count;
this.sortList = sortList;
SorterBuilder sorterBuilder = Sorter.builder();
for (Pair<SortOption, Expression> pair : sortList) {
Expand Down Expand Up @@ -97,8 +90,7 @@ public void open() {
sorted.add(input.next());
}

Iterator<ExprValue> sortedIterator = iterator(sorted);
iterator = count == 0 ? sortedIterator : Iterators.limit(sortedIterator, count);
iterator = iterator(sorted);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ public void sort_with_aggregator() {
"avg(integer_value)",
dsl.avg(DSL.ref("integer_value", INTEGER)))),
ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))),
0,
// Aggregator in Sort AST node is replaced with reference by expression optimizer
Pair.of(SortOption.DEFAULT_ASC, DSL.ref("avg(integer_value)", DOUBLE))),
DSL.named("string_value", DSL.ref("string_value", STRING))),
Expand All @@ -314,7 +313,6 @@ public void sort_with_aggregator() {
ImmutableList.of(AstDSL.alias("string_value", qualifiedName("string_value"))),
emptyList()
),
ImmutableList.of(argument("count", intLiteral(0))),
field(
function("avg", qualifiedName("integer_value")),
argument("asc", booleanLiteral(true)))),
Expand Down Expand Up @@ -353,13 +351,11 @@ public void sort_with_options() {
LogicalPlanDSL.project(
LogicalPlanDSL.sort(
LogicalPlanDSL.relation("test"),
0,
Pair.of(expectOption, DSL.ref("integer_value", INTEGER))),
DSL.named("string_value", DSL.ref("string_value", STRING))),
AstDSL.project(
AstDSL.sort(
AstDSL.relation("test"),
ImmutableList.of(argument("count", intLiteral(0))),
field(qualifiedName("integer_value"), args)),
AstDSL.alias("string_value", qualifiedName("string_value")))));
}
Expand All @@ -372,7 +368,6 @@ public void window_function() {
LogicalPlanDSL.window(
LogicalPlanDSL.sort(
LogicalPlanDSL.relation("test"),
0,
ImmutablePair.of(DEFAULT_ASC, DSL.ref("string_value", STRING)),
ImmutablePair.of(DEFAULT_ASC, DSL.ref("integer_value", INTEGER))),
dsl.rowNumber(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ void should_wrap_child_with_window_and_sort_operator_if_project_item_windowed()
LogicalPlanDSL.window(
LogicalPlanDSL.sort(
LogicalPlanDSL.relation("test"),
0,
ImmutablePair.of(DEFAULT_ASC, DSL.ref("string_value", STRING)),
ImmutablePair.of(DEFAULT_DESC, DSL.ref("integer_value", INTEGER))),
dsl.rowNumber(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ void can_explain_other_operators() {
dedupe(
sort(
values(values),
1000,
sortList),
dedupeList),
evalExprs),
Expand Down Expand Up @@ -233,7 +232,6 @@ void can_explain_other_operators() {
singletonList(new ExplainResponseNode(
"SortOperator",
ImmutableMap.of(
"count", 1000,
"sortList", ImmutableMap.of(
"age", ImmutableMap.of(
"sortOrder", "ASC",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ public void visitShouldReturnDefaultPhysicalOperator() {
mappings),
exclude),
newEvalField),
sortCount,
sortField),
CommandType.TOP,
topByExprs,
Expand Down Expand Up @@ -150,7 +149,6 @@ public void visitShouldReturnDefaultPhysicalOperator() {
mappings),
exclude),
newEvalField),
sortCount,
sortField),
CommandType.TOP,
topByExprs,
Expand Down Expand Up @@ -192,7 +190,6 @@ public void visitWindowOperatorShouldReturnPhysicalWindowOperator() {
window(
sort(
values(),
0,
sortList),
windowFunction,
windowDefinition),
Expand All @@ -203,7 +200,6 @@ public void visitWindowOperatorShouldReturnPhysicalWindowOperator() {
PhysicalPlanDSL.window(
PhysicalPlanDSL.sort(
PhysicalPlanDSL.values(),
0,
sortList),
windowFunction,
windowDefinition),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() {
assertNull(eval.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));

LogicalPlan sort = LogicalPlanDSL.sort(relation, 100,
LogicalPlan sort = LogicalPlanDSL.sort(relation,
Pair.of(SortOption.DEFAULT_ASC, expression));
assertNull(sort.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));
Expand Down
Loading