Skip to content

Commit

Permalink
Replace getScalarFunctiontImplementation with getScalarFunctionInvoker
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 30, 2020
1 parent b04a1ad commit 79318c6
Show file tree
Hide file tree
Showing 27 changed files with 192 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.NamedTypeSignature;
Expand All @@ -57,11 +58,14 @@
import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static io.prestosql.metadata.MetadataManager.createTestMetadataManager;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.OperatorType.IS_DISTINCT_FROM;
import static io.prestosql.spi.type.Decimals.encodeScaledValue;

Expand Down Expand Up @@ -211,7 +215,8 @@ public static Slice longDecimal(String value)
public static MethodHandle distinctFromOperator(Type type)
{
ResolvedFunction function = METADATA.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(type, type));
return METADATA.getScalarFunctionImplementation(function).getMethodHandle();
InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NULL_FLAG, NULL_FLAG), FAIL_ON_NULL, false, false);
return METADATA.getScalarFunctionInvoker(function, Optional.of(invocationConvention)).getMethodHandle();
}

public static boolean isDistinctFrom(MethodHandle handle, Block left, Block right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention;
import io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention;
import io.prestosql.spi.type.Type;
import io.prestosql.type.FunctionType;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
Expand Down Expand Up @@ -54,9 +55,13 @@ public FunctionInvokerProvider(Metadata metadata)
this.metadata = requireNonNull(metadata, "metadata is null");
}

public FunctionInvoker createFunctionInvoker(ScalarFunctionImplementation scalarFunctionImplementation, Signature resolvedSignature, Optional<InvocationConvention> invocationConvention)
public FunctionInvoker createFunctionInvoker(
FunctionMetadata functionMetadata,
ScalarFunctionImplementation scalarFunctionImplementation,
Signature resolvedSignature,
Optional<InvocationConvention> invocationConvention)
{
InvocationConvention expectedConvention = invocationConvention.orElseGet(() -> getDefaultCallingConvention(scalarFunctionImplementation));
InvocationConvention expectedConvention = invocationConvention.orElseGet(() -> getDefaultCallingConvention(functionMetadata));

List<Choice> choices = new ArrayList<>();
for (ScalarImplementationChoice choice : scalarFunctionImplementation.getAllChoices()) {
Expand All @@ -67,7 +72,7 @@ public FunctionInvoker createFunctionInvoker(ScalarFunctionImplementation scalar
}
if (choices.isEmpty()) {
throw new PrestoException(FUNCTION_NOT_FOUND,
format("Function implementation for (%s) cannot be adapted to convention (%s)", resolvedSignature, invocationConvention));
format("Function implementation for (%s) cannot be adapted to convention (%s)", resolvedSignature, expectedConvention));
}

Choice bestChoice = Collections.max(choices, comparingInt(Choice::getScore));
Expand All @@ -87,13 +92,12 @@ public FunctionInvoker createFunctionInvoker(ScalarFunctionImplementation scalar
* Default calling convention is no nulls and null is never returned. Since the no nulls adaptation strategy is to fail, the scalar must have this
* exact convention or convention must be specified.
*/
private static InvocationConvention getDefaultCallingConvention(ScalarFunctionImplementation scalarFunctionImplementation)
private static InvocationConvention getDefaultCallingConvention(FunctionMetadata functionMetadata)
{
List<InvocationArgumentConvention> argumentConventions = scalarFunctionImplementation.getArgumentProperties().stream()
.map(ArgumentProperty::getArgumentType)
.map(argumentProperty -> argumentProperty == FUNCTION_TYPE ? FUNCTION : NEVER_NULL)
List<InvocationArgumentConvention> argumentConventions = functionMetadata.getSignature().getArgumentTypes().stream()
.map(typeSignature -> typeSignature.getBase().equalsIgnoreCase(FunctionType.NAME) ? FUNCTION : NEVER_NULL)
.collect(toImmutableList());
InvocationReturnConvention returnConvention = scalarFunctionImplementation.isNullable() ? NULLABLE_RETURN : FAIL_ON_NULL;
InvocationReturnConvention returnConvention = functionMetadata.isNullable() ? NULLABLE_RETURN : FAIL_ON_NULL;

return new InvocationConvention(argumentConventions, returnConvention, true, false);
}
Expand Down
3 changes: 0 additions & 3 deletions presto-main/src/main/java/io/prestosql/metadata/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import io.prestosql.Session;
import io.prestosql.connector.CatalogName;
import io.prestosql.operator.aggregation.InternalAggregationFunction;
import io.prestosql.operator.scalar.ScalarFunctionImplementation;
import io.prestosql.operator.window.WindowFunctionSupplier;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.block.BlockEncoding;
Expand Down Expand Up @@ -479,8 +478,6 @@ default ResolvedFunction getCoercion(Type fromType, Type toType)

FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional<InvocationConvention> invocationConvention);

ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction);

ProcedureRegistry getProcedureRegistry();

//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1514,15 +1514,10 @@ public InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFu
@Override
public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional<InvocationConvention> invocationConvention)
{
FunctionMetadata functionMetadata = getFunctionMetadata(resolvedFunction);
ScalarFunctionImplementation scalarFunctionImplementation = functions.getScalarFunctionImplementation(this, resolvedFunction);
FunctionInvokerProvider functionInvokerProvider = new FunctionInvokerProvider(this);
return functionInvokerProvider.createFunctionInvoker(scalarFunctionImplementation, resolvedFunction.getSignature(), invocationConvention);
}

@Override
public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction)
{
return functions.getScalarFunctionImplementation(this, resolvedFunction);
return functionInvokerProvider.createFunctionInvoker(functionMetadata, scalarFunctionImplementation, resolvedFunction.getSignature(), invocationConvention);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.spi.Page;
import io.prestosql.spi.PageBuilder;
import io.prestosql.spi.block.Block;
Expand Down Expand Up @@ -70,8 +71,8 @@ public SimplePagesHashStrategy(
requireNonNull(metadata, "metadata is null");
ImmutableList.Builder<MethodHandle> distinctFromMethodHandlesBuilder = ImmutableList.builder();
for (int i = 0; i < hashChannels.size(); i++) {
distinctFromMethodHandlesBuilder.add(
metadata.getScalarFunctionImplementation(metadata.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(types.get(i), types.get(i)))).getMethodHandle());
ResolvedFunction resolvedFunction = metadata.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(types.get(i), types.get(i)));
distinctFromMethodHandlesBuilder.add(metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle());
}
distinctFromMethodHandles = distinctFromMethodHandlesBuilder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
Expand Down Expand Up @@ -106,7 +107,8 @@ protected AbstractMinMaxAggregationFunction(String name, boolean min, String des
public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, Metadata metadata)
{
Type type = boundVariables.getTypeVariable("E");
MethodHandle compareMethodHandle = metadata.getScalarFunctionImplementation(metadata.resolveOperator(operatorType, ImmutableList.of(type, type))).getMethodHandle();
ResolvedFunction resolvedFunction = metadata.resolveOperator(operatorType, ImmutableList.of(type, type));
MethodHandle compareMethodHandle = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();
return generateAggregation(type, compareMethodHandle);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.aggregation.AccumulatorCompiler;
Expand Down Expand Up @@ -156,7 +157,8 @@ private InternalAggregationFunction generateAggregation(Type valueType, Type key

CallSiteBinder binder = new CallSiteBinder();
OperatorType operator = min ? LESS_THAN : GREATER_THAN;
MethodHandle compareMethod = metadata.getScalarFunctionImplementation(metadata.resolveOperator(operator, ImmutableList.of(keyType, keyType))).getMethodHandle();
ResolvedFunction resolvedFunction = metadata.resolveOperator(operator, ImmutableList.of(keyType, keyType));
MethodHandle compareMethod = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();

ClassDefinition definition = new ClassDefinition(
a(PUBLIC, FINAL),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.spi.connector.RecordCursor;
import io.prestosql.spi.connector.RecordSet;
import io.prestosql.spi.type.Type;

import java.lang.invoke.MethodHandle;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -51,9 +53,9 @@ public FieldSetFilteringRecordSet(Metadata metadata, RecordSet delegate, List<Se
for (Set<Integer> fieldSet : requireNonNull(fieldSets, "fieldSets is null")) {
ImmutableSet.Builder<Field> fieldSetBuilder = ImmutableSet.builder();
for (int field : fieldSet) {
fieldSetBuilder.add(new Field(
field,
metadata.getScalarFunctionImplementation(metadata.resolveOperator(EQUAL, ImmutableList.of(columnTypes.get(field), columnTypes.get(field)))).getMethodHandle()));
ResolvedFunction resolvedFunction = metadata.resolveOperator(EQUAL, ImmutableList.of(columnTypes.get(field), columnTypes.get(field)));
MethodHandle methodHandle = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();
fieldSetBuilder.add(new Field(field, methodHandle));
}
fieldSetsBuilder.add(fieldSetBuilder.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlScalarFunction;
import io.prestosql.spi.PrestoException;
Expand All @@ -37,6 +38,7 @@

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -97,7 +99,8 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in
Type type = boundVariables.getTypeVariable("E");
checkArgument(type.isOrderable(), "Type must be orderable");

MethodHandle compareMethod = metadata.getScalarFunctionImplementation(metadata.resolveOperator(operatorType, ImmutableList.of(type, type))).getMethodHandle();
ResolvedFunction resolvedFunction = metadata.resolveOperator(operatorType, ImmutableList.of(type, type));
MethodHandle compareMethod = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();

List<Class<?>> javaTypes = IntStream.range(0, arity)
.mapToObj(i -> type.getJavaType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlScalarFunction;
import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty;
Expand Down Expand Up @@ -168,7 +169,8 @@ private static ScalarFunctionImplementation specializeArrayJoin(Map<String, Type
}
else {
try {
ScalarFunctionImplementation castFunction = metadata.getScalarFunctionImplementation(metadata.getCoercion(type, VARCHAR));
ResolvedFunction resolvedFunction = metadata.getCoercion(type, VARCHAR);
MethodHandle cast = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();

MethodHandle getter;
Class<?> elementType = type.getJavaType();
Expand All @@ -188,8 +190,6 @@ else if (elementType == Slice.class) {
throw new UnsupportedOperationException("Unsupported type: " + elementType.getName());
}

MethodHandle cast = castFunction.getMethodHandle();

// if the cast doesn't take a ConnectorSession, create an adapter that drops the provided session
if (cast.type().parameterArray()[0] != ConnectorSession.class) {
cast = MethodHandles.dropArguments(cast, 0, ConnectorSession.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.time.ZonedDateTime;
import java.util.IllegalFormatException;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;

import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -193,7 +194,7 @@ private static BiFunction<ConnectorSession, Block, Object> valueConverter(Metada
// TODO: support TIME WITH TIME ZONE by https://github.com/prestosql/presto/issues/191 + mapping to java.time.OffsetTime
if (type.equals(JSON)) {
ResolvedFunction function = metadata.resolveFunction(QualifiedName.of("json_format"), fromTypes(JSON));
MethodHandle handle = metadata.getScalarFunctionImplementation(function).getMethodHandle();
MethodHandle handle = metadata.getScalarFunctionInvoker(function, Optional.empty()).getMethodHandle();
return (session, block) -> convertToString(handle, type.getSlice(block, position));
}
if (isShortDecimal(type)) {
Expand Down Expand Up @@ -241,7 +242,7 @@ private static MethodHandle castToVarchar(Metadata metadata, Type type)
{
try {
ResolvedFunction cast = metadata.getCoercion(type, VARCHAR);
return metadata.getScalarFunctionImplementation(cast).getMethodHandle();
return metadata.getScalarFunctionInvoker(cast, Optional.empty()).getMethodHandle();
}
catch (OperatorNotFoundException e) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@
import io.prestosql.spi.block.DuplicateMapKeyException;
import io.prestosql.spi.block.MapBlockBuilder;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.TypeSignatureParameter;

import java.lang.invoke.MethodHandle;
import java.util.Optional;
Expand All @@ -43,9 +42,14 @@
import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty;
import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL;
import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.OperatorType.EQUAL;
import static io.prestosql.spi.function.OperatorType.HASH_CODE;
import static io.prestosql.spi.function.OperatorType.INDETERMINATE;
import static io.prestosql.spi.type.StandardTypes.MAP;
import static io.prestosql.spi.type.TypeSignature.arrayType;
import static io.prestosql.spi.type.TypeSignatureParameter.typeParameter;
import static io.prestosql.spi.type.TypeUtils.readNativeValue;
import static io.prestosql.util.Failures.checkCondition;
import static io.prestosql.util.Failures.internalError;
Expand Down Expand Up @@ -97,10 +101,11 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in
Type keyType = boundVariables.getTypeVariable("K");
Type valueType = boundVariables.getTypeVariable("V");

Type mapType = metadata.getParameterizedType(MAP, ImmutableList.of(TypeSignatureParameter.typeParameter(keyType.getTypeSignature()), TypeSignatureParameter.typeParameter(valueType.getTypeSignature())));
MethodHandle keyHashCode = metadata.getScalarFunctionImplementation(metadata.resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(keyType))).getMethodHandle();
MethodHandle keyEqual = metadata.getScalarFunctionImplementation(metadata.resolveOperator(OperatorType.EQUAL, ImmutableList.of(keyType, keyType))).getMethodHandle();
MethodHandle keyIndeterminate = metadata.getScalarFunctionImplementation(metadata.resolveOperator(INDETERMINATE, ImmutableList.of(keyType))).getMethodHandle();
Type mapType = metadata.getParameterizedType(MAP, ImmutableList.of(typeParameter(keyType.getTypeSignature()), typeParameter(valueType.getTypeSignature())));
MethodHandle keyHashCode = metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(keyType)), Optional.empty()).getMethodHandle();
MethodHandle keyEqual = metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(keyType, keyType)), Optional.empty()).getMethodHandle();
InvocationConvention indeterminateCallingConvention = new InvocationConvention(ImmutableList.of(NULL_FLAG), FAIL_ON_NULL, false, false);
MethodHandle keyIndeterminate = metadata.getScalarFunctionInvoker(metadata.resolveOperator(INDETERMINATE, ImmutableList.of(keyType)), Optional.of(indeterminateCallingConvention)).getMethodHandle();
MethodHandle instanceFactory = constructorMethodHandle(State.class, MapType.class).bindTo(mapType);

return new ScalarFunctionImplementation(
Expand Down
Loading

0 comments on commit 79318c6

Please sign in to comment.