Skip to content

Commit

Permalink
Use InterpretedFunctionInvoker in ExpressionOptimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 30, 2020
1 parent 95cfa60 commit 67291d1
Showing 1 changed file with 9 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
import io.prestosql.Session;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.InterpretedFunctionInvoker;
import io.prestosql.sql.relational.CallExpression;
import io.prestosql.sql.relational.ConstantExpression;
import io.prestosql.sql.relational.InputReferenceExpression;
Expand All @@ -34,12 +32,9 @@
import io.prestosql.sql.relational.SpecialForm;
import io.prestosql.sql.relational.VariableReferenceExpression;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.type.FunctionType;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -48,17 +43,11 @@
import static io.prestosql.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME;
import static io.prestosql.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME;
import static io.prestosql.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.prestosql.spi.function.OperatorType.CAST;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.spi.type.VarcharType.VARCHAR;
import static io.prestosql.sql.relational.Expressions.call;
import static io.prestosql.sql.relational.Expressions.constant;
import static io.prestosql.sql.relational.Expressions.constantNull;
import static io.prestosql.sql.relational.SpecialForm.Form.BIND;
import static io.prestosql.type.JsonType.JSON;

Expand Down Expand Up @@ -107,61 +96,23 @@ public RowExpression visitCall(CallExpression call, Void context)
// TODO: optimize function calls with lambda arguments. For example, apply(x -> x + 2, 1)
FunctionMetadata functionMetadata = metadata.getFunctionMetadata(call.getResolvedFunction());
if (arguments.stream().allMatch(ConstantExpression.class::isInstance) && functionMetadata.isDeterministic()) {
InvocationConvention convention = getInvocationConvention(call.getResolvedFunction(), functionMetadata);
MethodHandle method = metadata.getScalarFunctionInvoker(call.getResolvedFunction(), Optional.of(convention)).getMethodHandle();

List<Object> constantArguments = new ArrayList<>();
if (method.type().parameterCount() > 0 && method.type().parameterType(0) == ConnectorSession.class) {
constantArguments.add(session);
}

int index = 0;
for (RowExpression argument : arguments) {
Object value = ((ConstantExpression) argument).getValue();
// if any argument is null, return null
if (value == null && !functionMetadata.getArgumentDefinitions().get(index).isNullable()) {
return constantNull(call.getType());
}
constantArguments.add(value);
index++;
}
List<Object> constantArguments = arguments.stream()
.map(ConstantExpression.class::cast)
.map(ConstantExpression::getValue)
.collect(Collectors.toList());

try {
return constant(method.invokeWithArguments(constantArguments), call.getType());
InterpretedFunctionInvoker invoker = new InterpretedFunctionInvoker(metadata);
return constant(invoker.invoke(call.getResolvedFunction(), session, constantArguments), call.getType());
}
catch (Throwable e) {
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
catch (RuntimeException e) {
// Do nothing. As a result, this specific tree will be left untouched. But irrelevant expressions will continue to get evaluated and optimized.
}
}

return call(call.getResolvedFunction(), metadata.getType(call.getResolvedFunction().getSignature().getReturnType()), arguments);
}

private InvocationConvention getInvocationConvention(ResolvedFunction function, FunctionMetadata functionMetadata)
{
ImmutableList.Builder<InvocationArgumentConvention> argumentConventions = ImmutableList.builder();
for (int i = 0; i < functionMetadata.getArgumentDefinitions().size(); i++) {
if (function.getSignature().getArgumentTypes().get(i).getBase().equalsIgnoreCase(FunctionType.NAME)) {
argumentConventions.add(FUNCTION);
}
else if (functionMetadata.getArgumentDefinitions().get(i).isNullable()) {
argumentConventions.add(BOXED_NULLABLE);
}
else {
argumentConventions.add(NEVER_NULL);
}
}

return new InvocationConvention(
argumentConventions.build(),
functionMetadata.isNullable() ? NULLABLE_RETURN : FAIL_ON_NULL,
true,
true);
}

@Override
public RowExpression visitSpecialForm(SpecialForm specialForm, Void context)
{
Expand Down

0 comments on commit 67291d1

Please sign in to comment.