Skip to content

Commit

Permalink
Make FunctionInvokerProvider an internal detail of Metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 30, 2020
1 parent 0abb2c8 commit 7b6768d
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class FunctionInvokerProvider
class FunctionInvokerProvider
{
private final ScalarFunctionAdapter functionAdapter = new ScalarFunctionAdapter(UNSUPPORTED);
private final Metadata metadata;
Expand All @@ -50,23 +50,21 @@ public FunctionInvokerProvider(Metadata metadata)
this.metadata = requireNonNull(metadata, "metadata is null");
}

public FunctionInvoker createFunctionInvoker(ResolvedFunction resolvedFunction, Optional<InvocationConvention> invocationConvention)
public FunctionInvoker createFunctionInvoker(ScalarFunctionImplementation scalarFunctionImplementation, Signature resolvedSignature, Optional<InvocationConvention> invocationConvention)
{
ScalarFunctionImplementation scalarFunctionImplementation = metadata.getScalarFunctionImplementation(resolvedFunction);

InvocationConvention expectedConvention = invocationConvention.orElseGet(() -> getDefaultCallingConvention(scalarFunctionImplementation));

for (ScalarImplementationChoice choice : scalarFunctionImplementation.getAllChoices()) {
InvocationConvention callingConvention = toCallingConvention(choice);
if (functionAdapter.canAdapt(callingConvention, expectedConvention)) {
List<Type> actualTypes = resolvedFunction.getSignature().getArgumentTypes().stream()
List<Type> actualTypes = resolvedSignature.getArgumentTypes().stream()
.map(metadata::getType)
.collect(toImmutableList());
MethodHandle methodHandle = functionAdapter.adapt(choice.getMethodHandle(), actualTypes, callingConvention, expectedConvention);
return new FunctionInvoker(methodHandle, choice.getInstanceFactory());
}
}
throw new PrestoException(FUNCTION_NOT_FOUND, format("Dependent function implementation (%s) with convention (%s) is not available", resolvedFunction, invocationConvention.toString()));
throw new PrestoException(FUNCTION_NOT_FOUND, format("Dependent function implementation (%s) with convention (%s) is not available", resolvedSignature, invocationConvention.toString()));
}

/**
Expand Down
5 changes: 3 additions & 2 deletions presto-main/src/main/java/io/prestosql/metadata/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.prestosql.spi.connector.SampleType;
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.security.GrantInfo;
Expand Down Expand Up @@ -448,8 +449,6 @@ default Type getParameterizedType(String baseTypeName, List<TypeSignatureParamet

List<FunctionMetadata> listFunctions();

FunctionInvokerProvider getFunctionInvokerProvider();

ResolvedFunction resolveFunction(QualifiedName name, List<TypeSignatureProvider> parameterTypes);

ResolvedFunction resolveOperator(OperatorType operatorType, List<? extends Type> argumentTypes)
Expand Down Expand Up @@ -478,6 +477,8 @@ default ResolvedFunction getCoercion(Type fromType, Type toType)

InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFunction resolvedFunction);

FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional<InvocationConvention> invocationConvention);

ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction);

ProcedureRegistry getProcedureRegistry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.expression.Variable;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.security.GrantInfo;
Expand Down Expand Up @@ -1408,12 +1409,6 @@ public List<FunctionMetadata> listFunctions()
return functions.list();
}

@Override
public FunctionInvokerProvider getFunctionInvokerProvider()
{
return new FunctionInvokerProvider(this);
}

@Override
public ResolvedFunction resolveFunction(QualifiedName name, List<TypeSignatureProvider> parameterTypes)
{
Expand Down Expand Up @@ -1516,6 +1511,14 @@ public InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFu
return functions.getAggregateFunctionImplementation(this, resolvedFunction);
}

@Override
public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional<InvocationConvention> invocationConvention)
{
ScalarFunctionImplementation scalarFunctionImplementation = functions.getScalarFunctionImplementation(this, resolvedFunction);
FunctionInvokerProvider functionInvokerProvider = new FunctionInvokerProvider(this);
return functionInvokerProvider.createFunctionInvoker(scalarFunctionImplementation, resolvedFunction.getSignature(), invocationConvention);
}

@Override
public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
import java.lang.invoke.MethodHandle;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

public abstract class ScalarImplementationDependency
implements ImplementationDependency
{
private final Optional<InvocationConvention> invocationConvention;

protected ScalarImplementationDependency(Optional<InvocationConvention> invocationConvention)
{
this.invocationConvention = invocationConvention;
this.invocationConvention = requireNonNull(invocationConvention, "invocationConvention is null");
if (invocationConvention.map(InvocationConvention::supportsInstanceFactor).orElse(false)) {
throw new IllegalArgumentException(getClass().getSimpleName() + " does not support instance functions");
}
}

protected abstract ResolvedFunction getResolvedFunction(BoundVariables boundVariables, Metadata metadata);
Expand All @@ -37,7 +42,7 @@ protected ScalarImplementationDependency(Optional<InvocationConvention> invocati
public MethodHandle resolve(BoundVariables boundVariables, Metadata metadata)
{
ResolvedFunction resolvedFunction = getResolvedFunction(boundVariables, metadata);
return metadata.getFunctionInvokerProvider().createFunctionInvoker(resolvedFunction, invocationConvention).getMethodHandle();
return metadata.getScalarFunctionInvoker(resolvedFunction, invocationConvention).getMethodHandle();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in
Type type = boundVariables.getTypeVariable("T");
for (Type parameterType : type.getTypeParameters()) {
ResolvedFunction resolvedFunction = metadata.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(parameterType, parameterType));
FunctionInvoker functionInvoker = metadata.getFunctionInvokerProvider().createFunctionInvoker(
FunctionInvoker functionInvoker = metadata.getScalarFunctionInvoker(
resolvedFunction,
Optional.of(new InvocationConvention(
ImmutableList.of(NULL_FLAG, NULL_FLAG),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import io.prestosql.spi.connector.SampleType;
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.security.GrantInfo;
Expand Down Expand Up @@ -573,12 +574,6 @@ public List<FunctionMetadata> listFunctions()
throw new UnsupportedOperationException();
}

@Override
public FunctionInvokerProvider getFunctionInvokerProvider()
{
throw new UnsupportedOperationException();
}

@Override
public ResolvedFunction resolveFunction(QualifiedName name, List<TypeSignatureProvider> parameterTypes)
{
Expand Down Expand Up @@ -643,6 +638,12 @@ public InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFu
throw new UnsupportedOperationException();
}

@Override
public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional<InvocationConvention> invocationConvention)
{
throw new UnsupportedOperationException();
}

@Override
public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction)
{
Expand Down

0 comments on commit 7b6768d

Please sign in to comment.