Skip to content

Commit

Permalink
Do not create local projection for constant in PlanRemoteProjections
Browse files Browse the repository at this point in the history
Previously in PlanRemoteProjection, we might create extra assignments for
constants. This would cause plan failure in case of dereference. Essentially
we might rewrite
  `DEREFERENCE(expr, 0)`
to
  `DEREFERENCE(expr, expr1)`
    `expr1 = 0`

This was convenient since we didn't need to handle constant differently.
However, in RowExpressionInterpreter, it is assumed that the second parameter
of DEREFERENCE will be a number, which is no longer true with the change
in PlanRemoteProjection.

This PR changes the behavior of PlanRemoteProjections to not generate additional
projection for constant.
  • Loading branch information
rongrong authored and Rongrong Zhong committed Apr 26, 2021
1 parent 53ebcda commit 421d7f6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,16 +251,14 @@ public List<ProjectionContext> visitCall(CallExpression call, Void context)
boolean local = !functionMetadata.getImplementationType().equals(THRIFT);

// Break function arguments into local and remote projections first
ImmutableList.Builder<VariableReferenceExpression> newArgumentsBuilder = ImmutableList.builder();
ImmutableList.Builder<RowExpression> newArgumentsBuilder = ImmutableList.builder();
List<ProjectionContext> processedArguments = processArguments(call.getArguments(), newArgumentsBuilder);
List<VariableReferenceExpression> newArguments = newArgumentsBuilder.build();
List<RowExpression> newArguments = newArgumentsBuilder.build();
CallExpression newCall = new CallExpression(
call.getDisplayName(),
call.getFunctionHandle(),
call.getType(),
newArguments.stream()
.map(RowExpression.class::cast)
.collect(toImmutableList()));
newArguments);

if (local) {
if (processedArguments.size() == 1 && !processedArguments.get(0).isRemote()) {
Expand All @@ -280,7 +278,7 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) {
call.getFunctionHandle(),
call.getType(),
newArguments.stream()
.map(last.getProjections()::get)
.map(argument -> argument instanceof VariableReferenceExpression ? last.getProjections().get(argument) : argument)
.collect(toImmutableList()))),
false));
return projectionContextBuilder.build();
Expand All @@ -306,13 +304,13 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) {
@Override
public List<ProjectionContext> visitInputReference(InputReferenceExpression reference, Void context)
{
return ImmutableList.of();
throw new IllegalStateException("Optimizers should not see InputReferenceExpression");
}

@Override
public List<ProjectionContext> visitConstant(ConstantExpression literal, Void context)
{
return ImmutableList.of();
throw new IllegalStateException("We should not create ProjectionContext for constants");
}

@Override
Expand All @@ -330,9 +328,9 @@ public List<ProjectionContext> visitVariableReference(VariableReferenceExpressio
@Override
public List<ProjectionContext> visitSpecialForm(SpecialFormExpression specialForm, Void context)
{
ImmutableList.Builder<VariableReferenceExpression> newArgumentsBuilder = ImmutableList.builder();
ImmutableList.Builder<RowExpression> newArgumentsBuilder = ImmutableList.builder();
List<ProjectionContext> processedArguments = processArguments(specialForm.getArguments(), newArgumentsBuilder);
List<VariableReferenceExpression> newArguments = newArgumentsBuilder.build();
List<RowExpression> newArguments = newArgumentsBuilder.build();
if (processedArguments.size() == 1 && !processedArguments.get(0).isRemote()) {
// Arguments do not contain remote projection
return ImmutableList.of();
Expand All @@ -349,7 +347,7 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) {
specialForm.getForm(),
specialForm.getType(),
newArguments.stream()
.map(last.getProjections()::get)
.map(argument -> argument instanceof VariableReferenceExpression ? last.getProjections().get(argument) : argument)
.collect(toImmutableList()))),
false));
return projectionContextBuilder.build();
Expand All @@ -364,27 +362,30 @@ else if (!processedArguments.get(processedArguments.size() - 1).isRemote()) {
new SpecialFormExpression(
specialForm.getForm(),
specialForm.getType(),
newArguments.stream()
.map(RowExpression.class::cast)
.collect(toImmutableList()))),
newArguments)),
false));
return projectionContextBuilder.build();
}
}

private List<ProjectionContext> processArguments(List<RowExpression> arguments, ImmutableList.Builder<VariableReferenceExpression> newArguments)
private List<ProjectionContext> processArguments(List<RowExpression> arguments, ImmutableList.Builder<RowExpression> newArguments)
{
// Break function arguments into local and remote projections first
ImmutableList.Builder<List<ProjectionContext>> argumentProjections = ImmutableList.builder();

for (RowExpression argument : arguments) {
List<ProjectionContext> argumentProjection = argument.accept(this, null);
if (argumentProjection.isEmpty()) {
VariableReferenceExpression variable = variableAllocator.newVariable(argument);
argumentProjection = ImmutableList.of(new ProjectionContext(ImmutableMap.of(variable, argument), false));
if (argument instanceof ConstantExpression) {
newArguments.add(argument);
}
else {
List<ProjectionContext> argumentProjection = argument.accept(this, null);
if (argumentProjection.isEmpty()) {
VariableReferenceExpression variable = variableAllocator.newVariable(argument);
argumentProjection = ImmutableList.of(new ProjectionContext(ImmutableMap.of(variable, argument), false));
}
argumentProjections.add(argumentProjection);
newArguments.add(getAssignedArgument(argumentProjection));
}
argumentProjections.add(argumentProjection);
newArguments.add(getAssignedArgument(argumentProjection));
}
return mergeProjectionContexts(argumentProjections.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,19 @@ void testLocalOnly()
assertEquals(rewritten.get(0).getProjections().size(), 2);
}

@Test
void testRemoteWithConstantArgument()
{
PlanBuilder planBuilder = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), getMetadata());

PlanRemotePojections rule = new PlanRemotePojections(getFunctionAndTypeManager());
List<ProjectionContext> rewritten = rule.planRemoteAssignments(Assignments.builder()
.put(planBuilder.variable("a"), planBuilder.rowExpression("unittest.memory.remote_foo(0)"))
.put(planBuilder.variable("b"), planBuilder.rowExpression("unittest.memory.remote_foo()"))
.build(), new PlanVariableAllocator(planBuilder.getTypes().allVariables()));
assertEquals(rewritten.size(), 1);
}

@Test
void testRemoteOnly()
{
Expand Down Expand Up @@ -223,7 +236,7 @@ void testMixedExpressionRewrite()
p.variable("y", INTEGER);
return p.project(
Assignments.builder()
.put(p.variable("a"), p.rowExpression("unittest.memory.remote_foo(x, y + unittest.memory.remote_foo(x))")) // identity
.put(p.variable("a"), p.rowExpression("unittest.memory.remote_foo(1, y + unittest.memory.remote_foo(x))")) // identity
.put(p.variable("b"), p.rowExpression("x IS NULL OR y IS NULL")) // complex expression referenced multiple times
.put(p.variable("c"), p.rowExpression("abs(unittest.memory.remote_foo()) > 0")) // complex expression referenced multiple times
.put(p.variable("d"), p.rowExpression("unittest.memory.remote_foo(x + y, abs(x))")) // literal referenced multiple times
Expand All @@ -233,35 +246,31 @@ void testMixedExpressionRewrite()
.matches(
project(
ImmutableMap.of(
"a", PlanMatchPattern.expression("unittest.memory.remote_foo(x, add)"),
"a", PlanMatchPattern.expression("unittest.memory.remote_foo(1, add)"),
"b", PlanMatchPattern.expression("b"),
"c", PlanMatchPattern.expression("c"),
"d", PlanMatchPattern.expression("d")),
project(
ImmutableMap.of(
"x", PlanMatchPattern.expression("x"),
"add", PlanMatchPattern.expression("y + unittest_memory_remote_foo"),
"b", PlanMatchPattern.expression("b"),
"c", PlanMatchPattern.expression("abs(unittest_memory_remote_foo_7) > expr_8"),
"c", PlanMatchPattern.expression("abs(unittest_memory_remote_foo_7) > 0"),
"d", PlanMatchPattern.expression("d")),
project(
ImmutableMap.<String, ExpressionMatcher>builder()
.put("x", PlanMatchPattern.expression("x"))
.put("y", PlanMatchPattern.expression("y"))
.put("unittest_memory_remote_foo", PlanMatchPattern.expression("unittest.memory.remote_foo(x)"))
.put("b", PlanMatchPattern.expression("b"))
.put("unittest_memory_remote_foo_7", PlanMatchPattern.expression("unittest.memory.remote_foo()"))
.put("expr_8", PlanMatchPattern.expression("expr_8"))
.put("d", PlanMatchPattern.expression("unittest.memory.remote_foo(add_14, abs_16)"))
.put("d", PlanMatchPattern.expression("unittest.memory.remote_foo(add_9, abs_11)"))
.build(),
project(
ImmutableMap.<String, ExpressionMatcher>builder()
.put("x", PlanMatchPattern.expression("x"))
.put("y", PlanMatchPattern.expression("y"))
.put("b", PlanMatchPattern.expression("x IS NULL OR y is NULL"))
.put("expr_8", PlanMatchPattern.expression("0"))
.put("add_14", PlanMatchPattern.expression("x + y"))
.put("abs_16", PlanMatchPattern.expression("abs(x)"))
.put("add_9", PlanMatchPattern.expression("x + y"))
.put("abs_11", PlanMatchPattern.expression("abs(x)"))
.build(),
values(ImmutableMap.of("x", 0, "y", 1)))))));
}
Expand Down

0 comments on commit 421d7f6

Please sign in to comment.