Skip to content

Commit

Permalink
Change code generator to use ScalarFunctionInvoker
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 30, 2020
1 parent d57f5aa commit b04a1ad
Show file tree
Hide file tree
Showing 20 changed files with 283 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.prestosql.metadata;

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
Expand All @@ -22,11 +23,13 @@ public class FunctionInvoker
{
private final MethodHandle methodHandle;
private final Optional<MethodHandle> instanceFactory;
private final List<Optional<Class<?>>> lambdaInterfaces;

public FunctionInvoker(MethodHandle methodHandle, Optional<MethodHandle> instanceFactory)
public FunctionInvoker(MethodHandle methodHandle, Optional<MethodHandle> instanceFactory, List<Optional<Class<?>>> lambdaInterfaces)
{
this.methodHandle = requireNonNull(methodHandle, "methodHandle is null");
this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null");
this.lambdaInterfaces = requireNonNull(lambdaInterfaces, "lambdaInterfaces is null");
}

public MethodHandle getMethodHandle()
Expand All @@ -38,4 +41,9 @@ public Optional<MethodHandle> getInstanceFactory()
{
return instanceFactory;
}

public List<Optional<Class<?>>> getLambdaInterfaces()
{
return lambdaInterfaces;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
import io.prestosql.spi.type.Type;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.metadata.ScalarFunctionAdapter.NullAdaptationPolicy.UNSUPPORTED;
Expand All @@ -38,6 +41,7 @@
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static java.lang.String.format;
import static java.util.Comparator.comparingInt;
import static java.util.Objects.requireNonNull;

class FunctionInvokerProvider
Expand All @@ -54,17 +58,29 @@ public FunctionInvoker createFunctionInvoker(ScalarFunctionImplementation scalar
{
InvocationConvention expectedConvention = invocationConvention.orElseGet(() -> getDefaultCallingConvention(scalarFunctionImplementation));

List<Choice> choices = new ArrayList<>();
for (ScalarImplementationChoice choice : scalarFunctionImplementation.getAllChoices()) {
InvocationConvention callingConvention = toCallingConvention(choice);
if (functionAdapter.canAdapt(callingConvention, expectedConvention)) {
List<Type> actualTypes = resolvedSignature.getArgumentTypes().stream()
.map(metadata::getType)
.collect(toImmutableList());
MethodHandle methodHandle = functionAdapter.adapt(choice.getMethodHandle(), actualTypes, callingConvention, expectedConvention);
return new FunctionInvoker(methodHandle, choice.getInstanceFactory());
choices.add(new Choice(choice, callingConvention));
}
}
throw new PrestoException(FUNCTION_NOT_FOUND, format("Dependent function implementation (%s) with convention (%s) is not available", resolvedSignature, invocationConvention.toString()));
if (choices.isEmpty()) {
throw new PrestoException(FUNCTION_NOT_FOUND,
format("Function implementation for (%s) cannot be adapted to convention (%s)", resolvedSignature, invocationConvention));
}

Choice bestChoice = Collections.max(choices, comparingInt(Choice::getScore));
List<Type> actualTypes = resolvedSignature.getArgumentTypes().stream()
.map(metadata::getType)
.collect(toImmutableList());
MethodHandle methodHandle = functionAdapter.adapt(bestChoice.getChoice().getMethodHandle(), actualTypes, bestChoice.getCallingConvention(), expectedConvention);
return new FunctionInvoker(
methodHandle,
bestChoice.getChoice().getInstanceFactory(),
bestChoice.getChoice().getArgumentProperties().stream()
.map(ArgumentProperty::getLambdaInterface)
.collect(Collectors.toList()));
}

/**
Expand Down Expand Up @@ -111,4 +127,43 @@ private static InvocationArgumentConvention toArgumentConvention(ArgumentPropert
throw new IllegalArgumentException("Unsupported null convention: " + argumentProperty.getNullConvention());
}
}

private static final class Choice
{
private final ScalarImplementationChoice choice;
private final InvocationConvention callingConvention;
private final int score;

public Choice(ScalarImplementationChoice choice, InvocationConvention callingConvention)
{
this.choice = requireNonNull(choice, "choice is null");
this.callingConvention = requireNonNull(callingConvention, "callingConvention is null");

int score = 0;
for (InvocationArgumentConvention argument : callingConvention.getArgumentConventions()) {
if (argument == NULL_FLAG) {
score += 1;
}
else if (argument == BLOCK_POSITION) {
score += 1000;
}
}
this.score = score;
}

public ScalarImplementationChoice getChoice()
{
return choice;
}

public InvocationConvention getCallingConvention()
{
return callingConvention;
}

public int getScore()
{
return score;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,19 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in
Type toType = boundVariables.getTypeVariable("T");

ResolvedFunction resolvedFunction = metadata.getCoercion(fromType, toType);
ScalarFunctionImplementation function = metadata.getScalarFunctionImplementation(resolvedFunction);
Class<?> castOperatorClass = generateArrayCast(metadata, resolvedFunction.getSignature(), function);
Class<?> castOperatorClass = generateArrayCast(metadata, resolvedFunction);
MethodHandle methodHandle = methodHandle(castOperatorClass, "castArray", ConnectorSession.class, Block.class);
return new ScalarFunctionImplementation(
false,
ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)),
methodHandle);
}

private static Class<?> generateArrayCast(Metadata metadata, Signature elementCastSignature, ScalarFunctionImplementation elementCast)
private static Class<?> generateArrayCast(Metadata metadata, ResolvedFunction elementCast)
{
CallSiteBinder binder = new CallSiteBinder();

Signature elementCastSignature = elementCast.getSignature();
ClassDefinition definition = new ClassDefinition(
a(PUBLIC, FINAL),
makeClassName(Joiner.on("$").join("ArrayCast", elementCastSignature.getArgumentTypes().get(0), elementCastSignature.getReturnType())),
Expand All @@ -116,7 +116,7 @@ private static Class<?> generateArrayCast(Metadata metadata, Signature elementCa
Type fromElementType = metadata.getType(elementCastSignature.getArgumentTypes().get(0));
Type toElementType = metadata.getType(elementCastSignature.getReturnType());
CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(definition, binder);
ArrayMapBytecodeExpression newArray = ArrayGeneratorUtils.map(scope, cachedInstanceBinder, fromElementType, toElementType, value, elementCastSignature.getName(), elementCast);
ArrayMapBytecodeExpression newArray = ArrayGeneratorUtils.map(scope, cachedInstanceBinder, fromElementType, toElementType, value, elementCast, metadata);

// return the block
body.append(newArray.ret());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,10 @@ private static Class<?> generateIndeterminate(Type type, Metadata metadata)
.gotoLabel(end));

ResolvedFunction resolvedFunction = metadata.resolveOperator(INDETERMINATE, ImmutableList.of(fieldTypes.get(i)));
ScalarFunctionImplementation function = metadata.getScalarFunctionImplementation(resolvedFunction);
BytecodeExpression element = constantType(binder, fieldTypes.get(i)).getValue(value, constantInt(i));

ifNullField.ifFalse(new IfStatement("if the field is not null but indeterminate...")
.condition(invokeFunction(scope, cachedInstanceBinder, resolvedFunction.getSignature().getName(), function, element))
.condition(invokeFunction(scope, cachedInstanceBinder, resolvedFunction, metadata, element))
.ifTrue(new BytecodeBlock()
.push(true)
.gotoLabel(end)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,13 @@ private static Class<?> generateRowCast(Type fromType, Type toType, Metadata met
// loop through to append member blocks
for (int i = 0; i < toTypes.size(); i++) {
ResolvedFunction resolvedFunction = metadata.getCoercion(fromTypes.get(i), toTypes.get(i));
ScalarFunctionImplementation function = metadata.getScalarFunctionImplementation(resolvedFunction);
Type currentFromType = fromTypes.get(i);
if (currentFromType.equals(UNKNOWN)) {
body.append(singleRowBlockWriter.invoke("appendNull", BlockBuilder.class).pop());
continue;
}
BytecodeExpression fromElement = constantType(binder, currentFromType).getValue(value, constantInt(i));
BytecodeExpression toElement = invokeFunction(scope, cachedInstanceBinder, resolvedFunction.getSignature().getName(), function, fromElement);
BytecodeExpression toElement = invokeFunction(scope, cachedInstanceBinder, resolvedFunction, metadata, fromElement);
IfStatement ifElementNull = new IfStatement("if the element in the row type is null...");

ifElementNull.condition(value.invoke("isNull", boolean.class, constantInt(i)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,9 @@ public NullConvention getNullConvention()
return nullConvention.get();
}

public Class<?> getLambdaInterface()
public Optional<Class<?>> getLambdaInterface()
{
checkState(getArgumentType() == FUNCTION_TYPE, "lambdaInterface only applies to function type argument");
return lambdaInterface.get();
return lambdaInterface;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ private static MethodType javaMethodType(ParametricScalarImplementationChoice ch
}
break;
case FUNCTION_TYPE:
methodHandleParameterTypes.add(argumentProperty.getLambdaInterface());
methodHandleParameterTypes.add(argumentProperty.getLambdaInterface().orElseThrow(() -> new IllegalArgumentException("Argument is not a function")));
break;
default:
throw new UnsupportedOperationException("unknown ArgumentType");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public Object invoke(ResolvedFunction function, ConnectorSession session, List<O
}
}
else {
argument = asInterfaceInstance(argumentProperty.getLambdaInterface(), (MethodHandle) argument);
argument = asInterfaceInstance(method.type().parameterType(i), (MethodHandle) argument);
actualArguments.add(argument);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.prestosql.operator.scalar.ScalarFunctionImplementation;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.spi.type.Type;

import java.util.function.Function;
Expand All @@ -27,15 +28,22 @@ public final class ArrayGeneratorUtils
{
private ArrayGeneratorUtils() {}

public static ArrayMapBytecodeExpression map(Scope scope, CachedInstanceBinder cachedInstanceBinder, Type fromElementType, Type toElementType, Variable array, String elementFunctionName, ScalarFunctionImplementation elementFunction)
public static ArrayMapBytecodeExpression map(
Scope scope,
CachedInstanceBinder cachedInstanceBinder,
Type fromElementType,
Type toElementType,
Variable array,
ResolvedFunction elementFunction,
Metadata metadata)
{
return map(
scope,
cachedInstanceBinder.getCallSiteBinder(),
fromElementType,
toElementType,
array,
element -> invokeFunction(scope, cachedInstanceBinder, elementFunctionName, elementFunction, element));
element -> invokeFunction(scope, cachedInstanceBinder, elementFunction, metadata, element));
}

public static ArrayMapBytecodeExpression map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
package io.prestosql.sql.gen;

import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.FieldDefinition;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.prestosql.metadata.Metadata;
import io.prestosql.operator.scalar.ScalarFunctionImplementation;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.sql.relational.RowExpression;

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.sql.gen.BytecodeUtils.generateFullInvocation;
import static io.prestosql.sql.gen.BytecodeUtils.generateInvocation;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -69,12 +72,7 @@ public CallSiteBinder getCallSiteBinder()

public BytecodeNode generate(RowExpression expression)
{
return generate(expression, Optional.empty());
}

public BytecodeNode generate(RowExpression expression, Optional<Class<?>> lambdaInterface)
{
return rowExpressionCompiler.compile(expression, scope, lambdaInterface);
return rowExpressionCompiler.compile(expression, scope);
}

public Metadata getMetadata()
Expand All @@ -85,14 +83,25 @@ public Metadata getMetadata()
/**
* Generates a function call with null handling, automatic binding of session parameter, etc.
*/
public BytecodeNode generateCall(String name, ScalarFunctionImplementation function, List<BytecodeNode> arguments)
public BytecodeNode generateCall(ResolvedFunction resolvedFunction, List<BytecodeNode> arguments)
{
return generateInvocation(scope, resolvedFunction, metadata, arguments, callSiteBinder);
}

public BytecodeNode generateFullCall(ResolvedFunction resolvedFunction, List<RowExpression> arguments)
{
List<Function<Optional<Class<?>>, BytecodeNode>> argumentCompilers = arguments.stream()
.map(this::argumentCompiler)
.collect(toImmutableList());

Function<MethodHandle, BytecodeNode> instance = instanceFactory -> scope.getThis().getField(cachedInstanceBinder.getCachedInstance(instanceFactory));

return generateFullInvocation(scope, resolvedFunction, metadata, instance, argumentCompilers, callSiteBinder);
}

private Function<Optional<Class<?>>, BytecodeNode> argumentCompiler(RowExpression argument)
{
Optional<BytecodeNode> instance = Optional.empty();
if (function.getInstanceFactory().isPresent()) {
FieldDefinition field = cachedInstanceBinder.getCachedInstance(function.getInstanceFactory().get());
instance = Optional.of(scope.getThis().getField(field));
}
return generateInvocation(scope, name, function, instance, arguments, callSiteBinder);
return lambdaInterface -> rowExpressionCompiler.compile(argument, scope, lambdaInterface);
}

public Variable wasNull()
Expand Down
Loading

0 comments on commit b04a1ad

Please sign in to comment.