Skip to content

Commit

Permalink
Extract outer symbols from lambda expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
sopel39 committed Oct 3, 2024
1 parent 5ceaffb commit 7955c15
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public static Set<Symbol> extractUnique(WindowNode.Function function)
public static List<Symbol> extractAll(Expression expression)
{
ImmutableList.Builder<Symbol> builder = ImmutableList.builder();
new SymbolBuilderVisitor().process(expression, builder);
new SymbolBuilderVisitor().process(expression, new Context(ImmutableSet.of(), builder));
return builder.build();
}

Expand Down Expand Up @@ -132,20 +132,31 @@ public static Set<Symbol> extractOutputSymbols(PlanNode planNode, Lookup lookup)
}

private static class SymbolBuilderVisitor
extends DefaultTraversalVisitor<ImmutableList.Builder<Symbol>>
extends DefaultTraversalVisitor<Context>
{
@Override
protected Void visitReference(Reference node, ImmutableList.Builder<Symbol> builder)
protected Void visitReference(Reference node, Context context)
{
builder.add(Symbol.from(node));
Symbol symbol = Symbol.from(node);
if (!context.lambdaArguments().contains(symbol)) {
context.builder().add(symbol);
}
return null;
}

@Override
protected Void visitLambda(Lambda node, ImmutableList.Builder<Symbol> context)
protected Void visitLambda(Lambda node, Context context)
{
// Symbols in lambda expression are bound to lambda arguments, so no need to extract them
Context lambdaContext = new Context(
ImmutableSet.<Symbol>builder()
.addAll(context.lambdaArguments())
.addAll(node.arguments())
.build(),
context.builder());
process(node.body(), lambdaContext);
return null;
}
}

private record Context(Set<Symbol> lambdaArguments, ImmutableList.Builder<Symbol> builder) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.Assignments;
import io.trino.type.FunctionType;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
Expand Down Expand Up @@ -143,6 +145,23 @@ public void testEliminatesIdentityProjection()
values("x")));
}

@Test
public void testInlineLambda()
{
tester().assertThat(new InlineProjections())
.on(p -> p.project(
Assignments.of(
p.symbol("complex", INTEGER), new Reference(INTEGER, "complex"),
p.symbol("output", new FunctionType(ImmutableList.of(BIGINT), INTEGER)),
new Lambda(ImmutableList.of(p.symbol("arg")),
new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "arg"), new Reference(INTEGER, "complex"))))),
p.project(Assignments.builder()
.put(p.symbol("complex", INTEGER), new Call(SUBTRACT_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 1L))))
.build(),
p.values(p.symbol("x", INTEGER)))))
.doesNotFire();
}

@Test
public void testIdentityProjections()
{
Expand Down

0 comments on commit 7955c15

Please sign in to comment.