Skip to content

Commit

Permalink
Simplify lambda interface handling in FunctionInvoker
Browse files Browse the repository at this point in the history
Match lambda interface handling to aggregation and window functions
  • Loading branch information
dain committed Oct 7, 2020
1 parent e4be51a commit 2298c18
Show file tree
Hide file tree
Showing 16 changed files with 50 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
*/
package io.prestosql.metadata;

import com.google.common.collect.ImmutableList;

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;
Expand All @@ -23,9 +25,16 @@ public class FunctionInvoker
{
private final MethodHandle methodHandle;
private final Optional<MethodHandle> instanceFactory;
private final List<Optional<Class<?>>> lambdaInterfaces;
private final List<Class<?>> lambdaInterfaces;

public FunctionInvoker(MethodHandle methodHandle, Optional<MethodHandle> instanceFactory)
{
this.methodHandle = requireNonNull(methodHandle, "methodHandle is null");
this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null");
this.lambdaInterfaces = ImmutableList.of();
}

public FunctionInvoker(MethodHandle methodHandle, Optional<MethodHandle> instanceFactory, List<Optional<Class<?>>> lambdaInterfaces)
public FunctionInvoker(MethodHandle methodHandle, Optional<MethodHandle> instanceFactory, List<Class<?>> lambdaInterfaces)
{
this.methodHandle = requireNonNull(methodHandle, "methodHandle is null");
this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null");
Expand All @@ -42,7 +51,7 @@ public Optional<MethodHandle> getInstanceFactory()
return instanceFactory;
}

public List<Optional<Class<?>>> getLambdaInterfaces()
public List<Class<?>> getLambdaInterfaces()
{
return lambdaInterfaces;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

import static com.google.common.base.Preconditions.checkState;
import static java.util.Collections.emptyList;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;

class PolymorphicScalarFunction
Expand Down Expand Up @@ -88,7 +87,7 @@ private ScalarImplementationChoice getScalarFunctionImplementationChoice(
return new ScalarImplementationChoice(
choice.getReturnConvention(),
choice.getArgumentConventions(),
nCopies(choice.getArgumentConventions().size(), Optional.empty()),
ImmutableList.of(),
methodHandle,
Optional.empty());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ protected ScalarFunctionImplementation specialize(FunctionBinding functionBindin
return new ScalarFunctionImplementation(
NULLABLE_RETURN,
ImmutableList.of(BOXED_NULLABLE, FUNCTION),
ImmutableList.of(Optional.empty(), Optional.of(UnaryFunctionInterface.class)),
ImmutableList.of(UnaryFunctionInterface.class),
METHOD_HANDLE.asType(
METHOD_HANDLE.type()
.changeReturnType(Primitives.wrap(returnType.getJavaType()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,7 @@ protected ScalarFunctionImplementation specialize(FunctionBinding functionBindin
return new ScalarFunctionImplementation(
NULLABLE_RETURN,
ImmutableList.of(NEVER_NULL, BOXED_NULLABLE, FUNCTION, FUNCTION),
ImmutableList.of(
Optional.empty(),
Optional.empty(),
Optional.of(BinaryFunctionInterface.class),
Optional.of(UnaryFunctionInterface.class)),
ImmutableList.of(BinaryFunctionInterface.class, UnaryFunctionInterface.class),
methodHandle.asType(
methodHandle.type()
.changeParameterType(1, Primitives.wrap(intermediateType.getJavaType()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ protected ScalarFunctionImplementation specialize(FunctionBinding functionBindin
return new ScalarFunctionImplementation(
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, FUNCTION),
ImmutableList.of(Optional.empty(), Optional.of(UnaryFunctionInterface.class)),
ImmutableList.of(UnaryFunctionInterface.class),
methodHandle(generatedClass, "transform", PageBuilder.class, Block.class, UnaryFunctionInterface.class),
Optional.of(methodHandle(generatedClass, "createPageBuilder")));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ protected ScalarFunctionImplementation specialize(FunctionBinding functionBindin
return new ScalarFunctionImplementation(
NULLABLE_RETURN,
ImmutableList.of(FUNCTION),
ImmutableList.of(Optional.of(InvokeLambda.class)),
ImmutableList.of(InvokeLambda.class),
METHOD_HANDLE.asType(
METHOD_HANDLE.type()
.changeReturnType(Primitives.wrap(returnType.getJavaType()))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public ScalarFunctionImplementation specialize(FunctionBinding functionBinding)
return new ScalarFunctionImplementation(
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, FUNCTION),
ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)),
ImmutableList.of(BinaryFunctionInterface.class),
generateFilter(mapType),
Optional.of(STATE_FACTORY.bindTo(mapType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ protected ScalarFunctionImplementation specialize(FunctionBinding functionBindin
return new ScalarFunctionImplementation(
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, FUNCTION),
ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)),
ImmutableList.of(BinaryFunctionInterface.class),
generateTransformKey(keyType, transformedKeyType, valueType, resultMapType),
Optional.of(STATE_FACTORY.bindTo(resultMapType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ protected ScalarFunctionImplementation specialize(FunctionBinding functionBindin
return new ScalarFunctionImplementation(
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, FUNCTION),
ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)),
ImmutableList.of(BinaryFunctionInterface.class),
generateTransform(keyType, valueType, transformedValueType, resultMapType),
Optional.of(STATE_FACTORY.bindTo(resultMapType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ protected ScalarFunctionImplementation specialize(FunctionBinding functionBindin
return new ScalarFunctionImplementation(
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, NEVER_NULL, FUNCTION),
ImmutableList.of(Optional.empty(), Optional.empty(), Optional.of(MapZipWithLambda.class)),
ImmutableList.of(MapZipWithLambda.class),
METHOD_HANDLE.bindTo(keyType).bindTo(inputValueType1).bindTo(inputValueType2).bindTo(outputMapType),
Optional.of(STATE_FACTORY.bindTo(outputMapType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ public ScalarFunctionImplementation specialize(FunctionBinding functionBinding,
new ScalarImplementationChoice(
FAIL_ON_NULL,
ImmutableList.of(NULL_FLAG, NULL_FLAG),
ImmutableList.of(Optional.empty(), Optional.empty()),
ImmutableList.of(),
METHOD_HANDLE_NULL_FLAG.bindTo(type).bindTo(argumentMethods.build()),
Optional.empty()),
new ScalarImplementationChoice(
FAIL_ON_NULL,
ImmutableList.of(BLOCK_POSITION, BLOCK_POSITION),
ImmutableList.of(Optional.empty(), Optional.empty()),
ImmutableList.of(),
METHOD_HANDLE_BLOCK_POSITION.bindTo(type).bindTo(argumentMethods.build()),
Optional.empty())));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;

public final class ScalarFunctionImplementation
Expand All @@ -36,7 +35,7 @@ public ScalarFunctionImplementation(
List<InvocationArgumentConvention> argumentConventions,
MethodHandle methodHandle)
{
this(returnConvention, argumentConventions, nCopies(argumentConventions.size(), Optional.empty()), methodHandle, Optional.empty());
this(returnConvention, argumentConventions, ImmutableList.of(), methodHandle, Optional.empty());
}

public ScalarFunctionImplementation(
Expand All @@ -45,13 +44,13 @@ public ScalarFunctionImplementation(
MethodHandle methodHandle,
Optional<MethodHandle> instanceFactory)
{
this(returnConvention, argumentConventions, nCopies(argumentConventions.size(), Optional.empty()), methodHandle, instanceFactory);
this(returnConvention, argumentConventions, ImmutableList.of(), methodHandle, instanceFactory);
}

public ScalarFunctionImplementation(
InvocationReturnConvention returnConvention,
List<InvocationArgumentConvention> argumentConventions,
List<Optional<Class<?>>> lambdaInterfaces,
List<Class<?>> lambdaInterfaces,
MethodHandle methodHandle,
Optional<MethodHandle> instanceFactory)
{
Expand Down Expand Up @@ -83,12 +82,12 @@ public static class ScalarImplementationChoice
private final MethodHandle methodHandle;
private final Optional<MethodHandle> instanceFactory;
private final InvocationConvention invocationConvention;
private final List<Optional<Class<?>>> lambdaInterfaces;
private final List<Class<?>> lambdaInterfaces;

public ScalarImplementationChoice(
InvocationReturnConvention returnConvention,
List<InvocationArgumentConvention> argumentConventions,
List<Optional<Class<?>>> lambdaInterfaces,
List<Class<?>> lambdaInterfaces,
MethodHandle methodHandle,
Optional<MethodHandle> instanceFactory)
{
Expand Down Expand Up @@ -120,7 +119,7 @@ public ScalarImplementationChoice(
returnConvention,
hasSession,
instanceFactory.isPresent());
checkArgument(lambdaInterfaces.size() == argumentConventions.size());
checkArgument(lambdaInterfaces.size() <= argumentConventions.size());
}

public MethodHandle getMethodHandle()
Expand All @@ -133,7 +132,7 @@ public Optional<MethodHandle> getInstanceFactory()
return instanceFactory;
}

public List<Optional<Class<?>>> getLambdaInterfaces()
public List<Class<?>> getLambdaInterfaces()
{
return lambdaInterfaces;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ protected ScalarFunctionImplementation specialize(FunctionBinding functionBindin
return new ScalarFunctionImplementation(
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, NEVER_NULL, FUNCTION),
ImmutableList.of(Optional.empty(), Optional.empty(), Optional.of(BinaryFunctionInterface.class)),
ImmutableList.of(BinaryFunctionInterface.class),
METHOD_HANDLE.bindTo(leftElementType).bindTo(rightElementType).bindTo(outputArrayType),
Optional.of(STATE_FACTORY.bindTo(outputArrayType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ private static MethodType javaMethodType(ParametricScalarImplementationChoice ch
}

List<InvocationArgumentConvention> argumentConventions = choice.getArgumentConventions();
int lambdaArgumentIndex = 0;
for (int i = 0; i < argumentConventions.size(); i++) {
InvocationArgumentConvention argumentConvention = argumentConventions.get(i);
Type signatureType = signature.getArgumentTypes().get(i);
Expand All @@ -259,7 +260,8 @@ private static MethodType javaMethodType(ParametricScalarImplementationChoice ch
methodHandleParameterTypes.add(int.class);
break;
case FUNCTION:
methodHandleParameterTypes.add(choice.getLambdaInterfaces().get(i).orElseThrow(() -> new IllegalArgumentException("Argument is not a function")));
methodHandleParameterTypes.add(choice.getLambdaInterfaces().get(lambdaArgumentIndex));
lambdaArgumentIndex++;
break;
default:
throw new UnsupportedOperationException("unknown argument convention: " + argumentConvention);
Expand Down Expand Up @@ -338,7 +340,7 @@ public static final class ParametricScalarImplementationChoice
{
private final InvocationReturnConvention returnConvention;
private final List<InvocationArgumentConvention> argumentConventions;
private final List<Optional<Class<?>>> lambdaInterfaces;
private final List<Class<?>> lambdaInterfaces;
private final MethodHandle methodHandle;
private final Optional<MethodHandle> constructor;
private final List<ImplementationDependency> dependencies;
Expand All @@ -350,7 +352,7 @@ private ParametricScalarImplementationChoice(
InvocationReturnConvention returnConvention,
boolean hasConnectorSession,
List<InvocationArgumentConvention> argumentConventions,
List<Optional<Class<?>>> lambdaInterfaces,
List<Class<?>> lambdaInterfaces,
MethodHandle methodHandle,
Optional<MethodHandle> constructor,
List<ImplementationDependency> dependencies,
Expand Down Expand Up @@ -397,7 +399,7 @@ public List<InvocationArgumentConvention> getArgumentConventions()
return argumentConventions;
}

public List<Optional<Class<?>>> getLambdaInterfaces()
public List<Class<?>> getLambdaInterfaces()
{
return lambdaInterfaces;
}
Expand Down Expand Up @@ -479,7 +481,7 @@ public static final class Parser
{
private final String functionName;
private final List<InvocationArgumentConvention> argumentConventions = new ArrayList<>();
private final List<Optional<Class<?>>> lambdaInterfaces = new ArrayList<>();
private final List<Class<?>> lambdaInterfaces = new ArrayList<>();
private final TypeSignature returnType;
private final List<TypeSignature> argumentTypes = new ArrayList<>();
private final List<Optional<Class<?>>> argumentNativeContainerTypes = new ArrayList<>();
Expand Down Expand Up @@ -595,7 +597,7 @@ private void parseArguments(Method method)
// function type
checkCondition(parameterType.isAnnotationPresent(FunctionalInterface.class), FUNCTION_IMPLEMENTATION_ERROR, "argument %s is marked as lambda but the function interface class is not annotated: %s", parameterIndex, methodHandle);
argumentConventions.add(FUNCTION);
lambdaInterfaces.add(Optional.of(parameterType));
lambdaInterfaces.add(parameterType);
argumentNativeContainerTypes.add(Optional.empty());
parameterIndex++;
}
Expand Down Expand Up @@ -652,7 +654,6 @@ else if (Stream.of(annotations).anyMatch(BlockPosition.class::isInstance)) {
}

argumentConventions.add(argumentConvention);
lambdaInterfaces.add(Optional.empty());
parameterIndex++;
if (argumentConvention == NULL_FLAG || argumentConvention == BLOCK_POSITION) {
parameterIndex++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.prestosql.type.FunctionType.NAME;
import static java.lang.invoke.MethodHandleProxies.asInterfaceInstance;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -85,6 +86,7 @@ public static Object invoke(FunctionMetadata functionMetadata, FunctionInvoker i
actualArguments.add(session);
}

int lambdaArgumentIndex = 0;
for (int i = 0; i < arguments.size(); i++) {
Object argument = arguments.get(i);

Expand All @@ -93,9 +95,9 @@ public static Object invoke(FunctionMetadata functionMetadata, FunctionInvoker i
return null;
}

Optional<Class<?>> lambdaInterface = invoker.getLambdaInterfaces().get(i);
if (lambdaInterface.isPresent()) {
argument = asInterfaceInstance(lambdaInterface.get(), (MethodHandle) argument);
if (functionMetadata.getSignature().getArgumentTypes().get(i).getBase().equals(NAME)) {
argument = asInterfaceInstance(invoker.getLambdaInterfaces().get(lambdaArgumentIndex), (MethodHandle) argument);
lambdaArgumentIndex++;
}

actualArguments.add(argument);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ else if (functionMetadata.getArgumentDefinitions().get(i).isNullable()) {
// Index of parameter (without @IsNull) in Presto function
int realParameterIndex = 0;

// Index of function argument types
int lambdaArgumentIndex = 0;

MethodType methodType = binding.getType();
Class<?> returnType = methodType.returnType();
Class<?> unboxedReturnType = Primitives.unwrap(returnType);
Expand Down Expand Up @@ -331,8 +334,9 @@ else if (type == ConnectorSession.class) {
currentParameterIndex++;
break;
case FUNCTION:
Optional<Class<?>> lambdaInterface = functionInvoker.getLambdaInterfaces().get(realParameterIndex);
block.append(argumentCompilers.get(realParameterIndex).apply(lambdaInterface));
Class<?> lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex);
block.append(argumentCompilers.get(realParameterIndex).apply(Optional.of(lambdaInterface)));
lambdaArgumentIndex++;
break;
default:
throw new UnsupportedOperationException(format("Unsupported argument conventsion type: %s", invocationConvention.getArgumentConvention(realParameterIndex)));
Expand Down

0 comments on commit 2298c18

Please sign in to comment.