Skip to content

Commit

Permalink
Inline Window call
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Jun 18, 2019
1 parent 748f3fa commit 527cc40
Show file tree
Hide file tree
Showing 19 changed files with 120 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -900,10 +900,9 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext

FrameInfo frameInfo = new FrameInfo(frame.getType(), frame.getStartType(), frameStartChannel, frame.getEndType(), frameEndChannel);

FunctionCall functionCall = entry.getValue().getFunctionCall();
Signature signature = entry.getValue().getSignature();
ImmutableList.Builder<Integer> arguments = ImmutableList.builder();
for (Expression argument : functionCall.getArguments()) {
for (Expression argument : entry.getValue().getArguments()) {
Symbol argumentSymbol = Symbol.from(argument);
arguments.add(source.getLayout().get(argumentSymbol));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,9 @@ private PlanBuilder window(PlanBuilder subPlan, List<FunctionCall> windowFunctio
outputTranslations.put(windowFunction, newSymbol);

WindowNode.Function function = new WindowNode.Function(
(FunctionCall) rewritten, analysis.getFunctionSignature(windowFunction), frame);
analysis.getFunctionSignature(windowFunction),
((FunctionCall) rewritten).getArguments(),
frame);

List<Symbol> sourceSymbols = subPlan.getRoot().getOutputSymbols();
ImmutableList.Builder<Symbol> orderBySymbols = ImmutableList.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.plan.AggregationNode.Aggregation;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor;
import io.prestosql.sql.tree.DefaultTraversalVisitor;
import io.prestosql.sql.tree.DereferenceExpression;
Expand Down Expand Up @@ -84,6 +85,11 @@ public static Set<Symbol> extractUnique(Aggregation aggregation)
return ImmutableSet.copyOf(extractAll(aggregation));
}

public static Set<Symbol> extractUnique(WindowNode.Function function)
{
return ImmutableSet.copyOf(extractAll(function));
}

public static List<Symbol> extractAll(Expression expression)
{
ImmutableList.Builder<Symbol> builder = ImmutableList.builder();
Expand All @@ -102,6 +108,17 @@ public static List<Symbol> extractAll(Aggregation aggregation)
return builder.build();
}

public static List<Symbol> extractAll(WindowNode.Function function)
{
ImmutableList.Builder<Symbol> builder = ImmutableList.builder();
for (Expression argument : function.getArguments()) {
builder.addAll(extractAll(argument));
}
function.getFrame().getEndValue().ifPresent(builder::add);
function.getFrame().getStartValue().ifPresent(builder::add);
return builder.build();
}

// to extract qualified name with prefix
public static Set<QualifiedName> extractNames(Expression expression, Set<NodeRef<Expression>> columnReferences)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.FrameBound;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.GenericLiteral;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.Window;
import io.prestosql.sql.tree.WindowFrame;

import java.util.Optional;
Expand Down Expand Up @@ -80,16 +77,6 @@ public Result apply(LimitNode parent, Captures captures, Context context)
PlanNode child = captures.get(CHILD);
Symbol rankSymbol = context.getSymbolAllocator().newSymbol("rank_num", BIGINT);

FunctionCall functionCall = new FunctionCall(
QualifiedName.of("rank"),
Optional.of(
new Window(
ImmutableList.of(),
Optional.empty(),
Optional.empty())),
false,
ImmutableList.of());

Signature signature = new Signature(
"rank",
FunctionKind.WINDOW,
Expand All @@ -106,8 +93,8 @@ public Result apply(LimitNode parent, Captures captures, Context context)
Optional.empty());

WindowNode.Function rankFunction = new WindowNode.Function(
functionCall,
signature,
ImmutableList.of(),
frame);

WindowNode windowNode = new WindowNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ protected Optional<PlanNode> pushDownProjectOff(PlanNodeIdAllocator idAllocator,
windowNode.getHashSymbol().ifPresent(referencedInputs::add);

for (WindowNode.Function windowFunction : referencedFunctions.values()) {
referencedInputs.addAll(SymbolsExtractor.extractUnique(windowFunction.getFunctionCall()));
windowFunction.getFrame().getStartValue().ifPresent(referencedInputs::add);
windowFunction.getFrame().getEndValue().ifPresent(referencedInputs::add);
referencedInputs.addAll(SymbolsExtractor.extractUnique(windowFunction));
}

PlanNode prunedWindowNode = new WindowNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedIndex;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.sql.planner.DomainTranslator;
Expand All @@ -45,8 +46,10 @@
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.planner.plan.WindowNode.Function;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.sql.tree.WindowFrame;

Expand Down Expand Up @@ -349,7 +352,9 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Context> context)
public PlanNode visitWindow(WindowNode node, RewriteContext<Context> context)
{
if (!node.getWindowFunctions().values().stream()
.map(function -> function.getFunctionCall().getName())
.map(Function::getSignature)
.map(Signature::getName)
.map(QualifiedName::of)
.allMatch(metadata::isAggregationFunction)) {
return node;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;

import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -405,9 +404,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Set<Symbol>> context
WindowNode.Function function = entry.getValue();

if (context.get().contains(symbol)) {
FunctionCall call = function.getFunctionCall();
expectedInputs.addAll(SymbolsExtractor.extractUnique(call));

expectedInputs.addAll(SymbolsExtractor.extractUnique(function));
functionsBuilder.put(symbol, entry.getValue());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.NullLiteral;
import io.prestosql.sql.tree.SymbolReference;

Expand Down Expand Up @@ -196,11 +195,11 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
for (Map.Entry<Symbol, WindowNode.Function> entry : node.getWindowFunctions().entrySet()) {
Symbol symbol = entry.getKey();

FunctionCall canonicalFunctionCall = (FunctionCall) canonicalize(entry.getValue().getFunctionCall());
Signature signature = entry.getValue().getSignature();
List<Expression> arguments = canonicalize(entry.getValue().getArguments());
WindowNode.Frame canonicalFrame = canonicalize(entry.getValue().getFrame());

functions.put(canonicalize(symbol), new WindowNode.Function(canonicalFunctionCall, signature, canonicalFrame));
functions.put(canonicalize(symbol), new WindowNode.Function(signature, arguments, canonicalFrame));
}

return new WindowNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ public static boolean dependsOn(WindowNode parent, WindowNode child)
return parent.getPartitionBy().stream().anyMatch(child.getCreatedSymbols()::contains)
|| (parent.getOrderingScheme().isPresent() && parent.getOrderingScheme().get().getOrderBy().stream().anyMatch(child.getCreatedSymbols()::contains))
|| parent.getWindowFunctions().values().stream()
.map(WindowNode.Function::getFunctionCall)
.map(SymbolsExtractor::extractUnique)
.flatMap(Collection::stream)
.anyMatch(child.getCreatedSymbols()::contains);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FrameBound;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.WindowFrame;

import javax.annotation.concurrent.Immutable;
Expand Down Expand Up @@ -325,31 +324,31 @@ public int hashCode()
@Immutable
public static final class Function
{
private final FunctionCall functionCall;
private final Signature signature;
private final List<Expression> arguments;
private final Frame frame;

@JsonCreator
public Function(
@JsonProperty("functionCall") FunctionCall functionCall,
@JsonProperty("signature") Signature signature,
@JsonProperty("arguments") List<Expression> arguments,
@JsonProperty("frame") Frame frame)
{
this.functionCall = requireNonNull(functionCall, "functionCall is null");
this.signature = requireNonNull(signature, "Signature is null");
this.arguments = requireNonNull(arguments, "arguments is null");
this.frame = requireNonNull(frame, "Frame is null");
}

@JsonProperty
public FunctionCall getFunctionCall()
public Signature getSignature()
{
return functionCall;
return signature;
}

@JsonProperty
public Signature getSignature()
public List<Expression> getArguments()
{
return signature;
return arguments;
}

@JsonProperty
Expand All @@ -361,7 +360,7 @@ public Frame getFrame()
@Override
public int hashCode()
{
return Objects.hash(functionCall, signature, frame);
return Objects.hash(signature, arguments, frame);
}

@Override
Expand All @@ -374,8 +373,8 @@ public boolean equals(Object obj)
return false;
}
Function other = (Function) obj;
return Objects.equals(this.functionCall, other.functionCall) &&
Objects.equals(this.signature, other.signature) &&
return Objects.equals(this.signature, other.signature) &&
Objects.equals(this.arguments, other.arguments) &&
Objects.equals(this.frame, other.frame);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
import io.prestosql.sql.planner.planprinter.NodeRepresentation.TypedSymbol;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.util.GraphvizPrinter;

Expand Down Expand Up @@ -606,10 +605,15 @@ public Void visitWindow(WindowNode node, Void context)
NodeRepresentation nodeOutput = addNode(node, "Window", format("[%s]%s", Joiner.on(", ").join(args), formatHash(node.getHashSymbol())));

for (Map.Entry<Symbol, WindowNode.Function> entry : node.getWindowFunctions().entrySet()) {
FunctionCall call = entry.getValue().getFunctionCall();
String frameInfo = formatFrame(entry.getValue().getFrame());
WindowNode.Function function = entry.getValue();
String frameInfo = formatFrame(function.getFrame());

nodeOutput.appendDetailsLine("%s := %s(%s) %s", entry.getKey(), call.getName(), Joiner.on(", ").join(call.getArguments()), frameInfo);
nodeOutput.appendDetailsLine(
"%s := %s(%s) %s",
entry.getKey(),
function.getSignature().getName(),
Joiner.on(", ").join(function.getArguments()),
frameInfo);
}
return processChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,10 @@ public Void visitUnion(UnionNode node, Void context)

private void checkWindowFunctions(Map<Symbol, WindowNode.Function> functions)
{
for (Map.Entry<Symbol, WindowNode.Function> entry : functions.entrySet()) {
Signature signature = entry.getValue().getSignature();
FunctionCall call = entry.getValue().getFunctionCall();

checkSignature(entry.getKey(), signature);
checkCall(entry.getKey(), call);
}
functions.forEach((symbol, function) -> {
checkSignature(symbol, function.getSignature());
checkCall(symbol, function.getSignature().getName(), function.getArguments());
});
}

private void checkSignature(Symbol symbol, Signature signature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public Void visitWindow(WindowNode node, Set<Symbol> boundSymbols)
checkDependencies(inputs, bounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)", bounds.build(), node.getSource().getOutputSymbols());

for (WindowNode.Function function : node.getWindowFunctions().values()) {
Set<Symbol> dependencies = SymbolsExtractor.extractUnique(function.getFunctionCall());
Set<Symbol> dependencies = SymbolsExtractor.extractUnique(function);
checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FrameBound;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.WindowFrame;
import io.prestosql.testing.TestingMetadata.TestingColumnHandle;
import org.testng.annotations.BeforeMethod;
Expand Down Expand Up @@ -149,7 +147,6 @@ public void testValidWindow()
DOUBLE.getTypeSignature(),
ImmutableList.of(DOUBLE.getTypeSignature()),
false);
FunctionCall functionCall = new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference()));

WindowNode.Frame frame = new WindowNode.Frame(
WindowFrame.Type.RANGE,
Expand All @@ -160,7 +157,7 @@ public void testValidWindow()
Optional.empty(),
Optional.empty());

WindowNode.Function function = new WindowNode.Function(functionCall, signature, frame);
WindowNode.Function function = new WindowNode.Function(signature, ImmutableList.of(columnC.toSymbolReference()), frame);

WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty());

Expand Down Expand Up @@ -311,7 +308,6 @@ public void testInvalidWindowFunctionCall()
DOUBLE.getTypeSignature(),
ImmutableList.of(DOUBLE.getTypeSignature()),
false);
FunctionCall functionCall = new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnA.toSymbolReference())); // should be columnC

WindowNode.Frame frame = new WindowNode.Frame(
WindowFrame.Type.RANGE,
Expand All @@ -322,7 +318,7 @@ public void testInvalidWindowFunctionCall()
Optional.empty(),
Optional.empty());

WindowNode.Function function = new WindowNode.Function(functionCall, signature, frame);
WindowNode.Function function = new WindowNode.Function(signature, ImmutableList.of(columnA.toSymbolReference()), frame);

WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty());

Expand Down Expand Up @@ -350,7 +346,6 @@ public void testInvalidWindowFunctionSignature()
BIGINT.getTypeSignature(), // should be DOUBLE
ImmutableList.of(DOUBLE.getTypeSignature()),
false);
FunctionCall functionCall = new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference()));

WindowNode.Frame frame = new WindowNode.Frame(
WindowFrame.Type.RANGE,
Expand All @@ -361,7 +356,7 @@ public void testInvalidWindowFunctionSignature()
Optional.empty(),
Optional.empty());

WindowNode.Function function = new WindowNode.Function(functionCall, signature, frame);
WindowNode.Function function = new WindowNode.Function(signature, ImmutableList.of(columnC.toSymbolReference()), frame);

WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty());

Expand Down
Loading

0 comments on commit 527cc40

Please sign in to comment.