diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/HiveTestUtils.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/HiveTestUtils.java index 5fa9712125af6..4776b8001da72 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/HiveTestUtils.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/HiveTestUtils.java @@ -43,6 +43,7 @@ import io.prestosql.spi.block.Block; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.type.ArrayType; import io.prestosql.spi.type.MapType; import io.prestosql.spi.type.NamedTypeSignature; @@ -57,11 +58,14 @@ import java.lang.invoke.MethodHandle; import java.math.BigDecimal; import java.util.List; +import java.util.Optional; import java.util.Set; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.prestosql.spi.function.OperatorType.IS_DISTINCT_FROM; import static io.prestosql.spi.type.Decimals.encodeScaledValue; @@ -211,7 +215,8 @@ public static Slice longDecimal(String value) public static MethodHandle distinctFromOperator(Type type) { ResolvedFunction function = METADATA.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(type, type)); - return METADATA.getScalarFunctionImplementation(function).getMethodHandle(); + InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NULL_FLAG, NULL_FLAG), FAIL_ON_NULL, false, false); + return METADATA.getScalarFunctionInvoker(function, Optional.of(invocationConvention)).getMethodHandle(); } public static boolean isDistinctFrom(MethodHandle handle, Block left, Block right) diff --git a/presto-main/src/main/java/io/prestosql/metadata/FunctionInvokerProvider.java b/presto-main/src/main/java/io/prestosql/metadata/FunctionInvokerProvider.java index e77605a8accee..3f3a649626e14 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/FunctionInvokerProvider.java +++ b/presto-main/src/main/java/io/prestosql/metadata/FunctionInvokerProvider.java @@ -21,6 +21,7 @@ import io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention; import io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention; import io.prestosql.spi.type.Type; +import io.prestosql.type.FunctionType; import java.lang.invoke.MethodHandle; import java.util.ArrayList; @@ -54,9 +55,13 @@ public FunctionInvokerProvider(Metadata metadata) this.metadata = requireNonNull(metadata, "metadata is null"); } - public FunctionInvoker createFunctionInvoker(ScalarFunctionImplementation scalarFunctionImplementation, Signature resolvedSignature, Optional invocationConvention) + public FunctionInvoker createFunctionInvoker( + FunctionMetadata functionMetadata, + ScalarFunctionImplementation scalarFunctionImplementation, + Signature resolvedSignature, + Optional invocationConvention) { - InvocationConvention expectedConvention = invocationConvention.orElseGet(() -> getDefaultCallingConvention(scalarFunctionImplementation)); + InvocationConvention expectedConvention = invocationConvention.orElseGet(() -> getDefaultCallingConvention(functionMetadata)); List choices = new ArrayList<>(); for (ScalarImplementationChoice choice : scalarFunctionImplementation.getAllChoices()) { @@ -67,7 +72,7 @@ public FunctionInvoker createFunctionInvoker(ScalarFunctionImplementation scalar } if (choices.isEmpty()) { throw new PrestoException(FUNCTION_NOT_FOUND, - format("Function implementation for (%s) cannot be adapted to convention (%s)", resolvedSignature, invocationConvention)); + format("Function implementation for (%s) cannot be adapted to convention (%s)", resolvedSignature, expectedConvention)); } Choice bestChoice = Collections.max(choices, comparingInt(Choice::getScore)); @@ -87,13 +92,12 @@ public FunctionInvoker createFunctionInvoker(ScalarFunctionImplementation scalar * Default calling convention is no nulls and null is never returned. Since the no nulls adaptation strategy is to fail, the scalar must have this * exact convention or convention must be specified. */ - private static InvocationConvention getDefaultCallingConvention(ScalarFunctionImplementation scalarFunctionImplementation) + private static InvocationConvention getDefaultCallingConvention(FunctionMetadata functionMetadata) { - List argumentConventions = scalarFunctionImplementation.getArgumentProperties().stream() - .map(ArgumentProperty::getArgumentType) - .map(argumentProperty -> argumentProperty == FUNCTION_TYPE ? FUNCTION : NEVER_NULL) + List argumentConventions = functionMetadata.getSignature().getArgumentTypes().stream() + .map(typeSignature -> typeSignature.getBase().equalsIgnoreCase(FunctionType.NAME) ? FUNCTION : NEVER_NULL) .collect(toImmutableList()); - InvocationReturnConvention returnConvention = scalarFunctionImplementation.isNullable() ? NULLABLE_RETURN : FAIL_ON_NULL; + InvocationReturnConvention returnConvention = functionMetadata.isNullable() ? NULLABLE_RETURN : FAIL_ON_NULL; return new InvocationConvention(argumentConventions, returnConvention, true, false); } diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java index 889c50327a6c0..3b7123a3e2202 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java +++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java @@ -17,7 +17,6 @@ import io.prestosql.Session; import io.prestosql.connector.CatalogName; import io.prestosql.operator.aggregation.InternalAggregationFunction; -import io.prestosql.operator.scalar.ScalarFunctionImplementation; import io.prestosql.operator.window.WindowFunctionSupplier; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.BlockEncoding; @@ -479,8 +478,6 @@ default ResolvedFunction getCoercion(Type fromType, Type toType) FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional invocationConvention); - ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction); - ProcedureRegistry getProcedureRegistry(); // diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 6392131c515cd..39b20e79e7ff3 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -1514,15 +1514,10 @@ public InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFu @Override public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional invocationConvention) { + FunctionMetadata functionMetadata = getFunctionMetadata(resolvedFunction); ScalarFunctionImplementation scalarFunctionImplementation = functions.getScalarFunctionImplementation(this, resolvedFunction); FunctionInvokerProvider functionInvokerProvider = new FunctionInvokerProvider(this); - return functionInvokerProvider.createFunctionInvoker(scalarFunctionImplementation, resolvedFunction.getSignature(), invocationConvention); - } - - @Override - public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction) - { - return functions.getScalarFunctionImplementation(this, resolvedFunction); + return functionInvokerProvider.createFunctionInvoker(functionMetadata, scalarFunctionImplementation, resolvedFunction.getSignature(), invocationConvention); } @Override diff --git a/presto-main/src/main/java/io/prestosql/operator/SimplePagesHashStrategy.java b/presto-main/src/main/java/io/prestosql/operator/SimplePagesHashStrategy.java index dc0424a068226..de153273d730e 100644 --- a/presto-main/src/main/java/io/prestosql/operator/SimplePagesHashStrategy.java +++ b/presto-main/src/main/java/io/prestosql/operator/SimplePagesHashStrategy.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.ResolvedFunction; import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.block.Block; @@ -70,8 +71,8 @@ public SimplePagesHashStrategy( requireNonNull(metadata, "metadata is null"); ImmutableList.Builder distinctFromMethodHandlesBuilder = ImmutableList.builder(); for (int i = 0; i < hashChannels.size(); i++) { - distinctFromMethodHandlesBuilder.add( - metadata.getScalarFunctionImplementation(metadata.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(types.get(i), types.get(i)))).getMethodHandle()); + ResolvedFunction resolvedFunction = metadata.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(types.get(i), types.get(i))); + distinctFromMethodHandlesBuilder.add(metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle()); } distinctFromMethodHandles = distinctFromMethodHandlesBuilder.build(); } diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java index 5f32a6f4bc8cd..8eaf0a34909f0 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java @@ -20,6 +20,7 @@ import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.Signature; import io.prestosql.metadata.SqlAggregationFunction; import io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; @@ -106,7 +107,8 @@ protected AbstractMinMaxAggregationFunction(String name, boolean min, String des public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, Metadata metadata) { Type type = boundVariables.getTypeVariable("E"); - MethodHandle compareMethodHandle = metadata.getScalarFunctionImplementation(metadata.resolveOperator(operatorType, ImmutableList.of(type, type))).getMethodHandle(); + ResolvedFunction resolvedFunction = metadata.resolveOperator(operatorType, ImmutableList.of(type, type)); + MethodHandle compareMethodHandle = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle(); return generateAggregation(type, compareMethodHandle); } diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java index e8e3818879967..ee778e84a67ba 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java @@ -27,6 +27,7 @@ import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.Signature; import io.prestosql.metadata.SqlAggregationFunction; import io.prestosql.operator.aggregation.AccumulatorCompiler; @@ -156,7 +157,8 @@ private InternalAggregationFunction generateAggregation(Type valueType, Type key CallSiteBinder binder = new CallSiteBinder(); OperatorType operator = min ? LESS_THAN : GREATER_THAN; - MethodHandle compareMethod = metadata.getScalarFunctionImplementation(metadata.resolveOperator(operator, ImmutableList.of(keyType, keyType))).getMethodHandle(); + ResolvedFunction resolvedFunction = metadata.resolveOperator(operator, ImmutableList.of(keyType, keyType)); + MethodHandle compareMethod = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle(); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), diff --git a/presto-main/src/main/java/io/prestosql/operator/index/FieldSetFilteringRecordSet.java b/presto-main/src/main/java/io/prestosql/operator/index/FieldSetFilteringRecordSet.java index f0921ad7ad841..6c80ff97083ca 100644 --- a/presto-main/src/main/java/io/prestosql/operator/index/FieldSetFilteringRecordSet.java +++ b/presto-main/src/main/java/io/prestosql/operator/index/FieldSetFilteringRecordSet.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.ResolvedFunction; import io.prestosql.spi.connector.RecordCursor; import io.prestosql.spi.connector.RecordSet; import io.prestosql.spi.type.Type; @@ -24,6 +25,7 @@ import java.lang.invoke.MethodHandle; import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -51,9 +53,9 @@ public FieldSetFilteringRecordSet(Metadata metadata, RecordSet delegate, List fieldSet : requireNonNull(fieldSets, "fieldSets is null")) { ImmutableSet.Builder fieldSetBuilder = ImmutableSet.builder(); for (int field : fieldSet) { - fieldSetBuilder.add(new Field( - field, - metadata.getScalarFunctionImplementation(metadata.resolveOperator(EQUAL, ImmutableList.of(columnTypes.get(field), columnTypes.get(field)))).getMethodHandle())); + ResolvedFunction resolvedFunction = metadata.resolveOperator(EQUAL, ImmutableList.of(columnTypes.get(field), columnTypes.get(field))); + MethodHandle methodHandle = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle(); + fieldSetBuilder.add(new Field(field, methodHandle)); } fieldSetsBuilder.add(fieldSetBuilder.build()); } diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/AbstractGreatestLeast.java b/presto-main/src/main/java/io/prestosql/operator/scalar/AbstractGreatestLeast.java index c0ab8c71a8ccf..863e596a330d4 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/AbstractGreatestLeast.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/AbstractGreatestLeast.java @@ -27,6 +27,7 @@ import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.Signature; import io.prestosql.metadata.SqlScalarFunction; import io.prestosql.spi.PrestoException; @@ -37,6 +38,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; @@ -97,7 +99,8 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type type = boundVariables.getTypeVariable("E"); checkArgument(type.isOrderable(), "Type must be orderable"); - MethodHandle compareMethod = metadata.getScalarFunctionImplementation(metadata.resolveOperator(operatorType, ImmutableList.of(type, type))).getMethodHandle(); + ResolvedFunction resolvedFunction = metadata.resolveOperator(operatorType, ImmutableList.of(type, type)); + MethodHandle compareMethod = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle(); List> javaTypes = IntStream.range(0, arity) .mapToObj(i -> type.getJavaType()) diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayJoin.java b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayJoin.java index f734bbe90655d..c8d923f648df6 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayJoin.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayJoin.java @@ -20,6 +20,7 @@ import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.Signature; import io.prestosql.metadata.SqlScalarFunction; import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; @@ -168,7 +169,8 @@ private static ScalarFunctionImplementation specializeArrayJoin(Map elementType = type.getJavaType(); @@ -188,8 +190,6 @@ else if (elementType == Slice.class) { throw new UnsupportedOperationException("Unsupported type: " + elementType.getName()); } - MethodHandle cast = castFunction.getMethodHandle(); - // if the cast doesn't take a ConnectorSession, create an adapter that drops the provided session if (cast.type().parameterArray()[0] != ConnectorSession.class) { cast = MethodHandles.dropArguments(cast, 0, ConnectorSession.class); diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/FormatFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/FormatFunction.java index 7578a88bb7d5d..592b3492d2c98 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/FormatFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/FormatFunction.java @@ -44,6 +44,7 @@ import java.time.ZonedDateTime; import java.util.IllegalFormatException; import java.util.List; +import java.util.Optional; import java.util.function.BiFunction; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -193,7 +194,7 @@ private static BiFunction valueConverter(Metada // TODO: support TIME WITH TIME ZONE by https://github.com/prestosql/presto/issues/191 + mapping to java.time.OffsetTime if (type.equals(JSON)) { ResolvedFunction function = metadata.resolveFunction(QualifiedName.of("json_format"), fromTypes(JSON)); - MethodHandle handle = metadata.getScalarFunctionImplementation(function).getMethodHandle(); + MethodHandle handle = metadata.getScalarFunctionInvoker(function, Optional.empty()).getMethodHandle(); return (session, block) -> convertToString(handle, type.getSlice(block, position)); } if (isShortDecimal(type)) { @@ -241,7 +242,7 @@ private static MethodHandle castToVarchar(Metadata metadata, Type type) { try { ResolvedFunction cast = metadata.getCoercion(type, VARCHAR); - return metadata.getScalarFunctionImplementation(cast).getMethodHandle(); + return metadata.getScalarFunctionInvoker(cast, Optional.empty()).getMethodHandle(); } catch (OperatorNotFoundException e) { return null; diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/MapConstructor.java b/presto-main/src/main/java/io/prestosql/operator/scalar/MapConstructor.java index 4c36c0c54731d..7c47b896ec389 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/MapConstructor.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/MapConstructor.java @@ -28,11 +28,10 @@ import io.prestosql.spi.block.DuplicateMapKeyException; import io.prestosql.spi.block.MapBlockBuilder; import io.prestosql.spi.connector.ConnectorSession; -import io.prestosql.spi.function.OperatorType; +import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.type.MapType; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; -import io.prestosql.spi.type.TypeSignatureParameter; import java.lang.invoke.MethodHandle; import java.util.Optional; @@ -43,9 +42,14 @@ import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.prestosql.spi.function.OperatorType.EQUAL; +import static io.prestosql.spi.function.OperatorType.HASH_CODE; import static io.prestosql.spi.function.OperatorType.INDETERMINATE; import static io.prestosql.spi.type.StandardTypes.MAP; import static io.prestosql.spi.type.TypeSignature.arrayType; +import static io.prestosql.spi.type.TypeSignatureParameter.typeParameter; import static io.prestosql.spi.type.TypeUtils.readNativeValue; import static io.prestosql.util.Failures.checkCondition; import static io.prestosql.util.Failures.internalError; @@ -97,10 +101,11 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - Type mapType = metadata.getParameterizedType(MAP, ImmutableList.of(TypeSignatureParameter.typeParameter(keyType.getTypeSignature()), TypeSignatureParameter.typeParameter(valueType.getTypeSignature()))); - MethodHandle keyHashCode = metadata.getScalarFunctionImplementation(metadata.resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(keyType))).getMethodHandle(); - MethodHandle keyEqual = metadata.getScalarFunctionImplementation(metadata.resolveOperator(OperatorType.EQUAL, ImmutableList.of(keyType, keyType))).getMethodHandle(); - MethodHandle keyIndeterminate = metadata.getScalarFunctionImplementation(metadata.resolveOperator(INDETERMINATE, ImmutableList.of(keyType))).getMethodHandle(); + Type mapType = metadata.getParameterizedType(MAP, ImmutableList.of(typeParameter(keyType.getTypeSignature()), typeParameter(valueType.getTypeSignature()))); + MethodHandle keyHashCode = metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(keyType)), Optional.empty()).getMethodHandle(); + MethodHandle keyEqual = metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(keyType, keyType)), Optional.empty()).getMethodHandle(); + InvocationConvention indeterminateCallingConvention = new InvocationConvention(ImmutableList.of(NULL_FLAG), FAIL_ON_NULL, false, false); + MethodHandle keyIndeterminate = metadata.getScalarFunctionInvoker(metadata.resolveOperator(INDETERMINATE, ImmutableList.of(keyType)), Optional.of(indeterminateCallingConvention)).getMethodHandle(); MethodHandle instanceFactory = constructorMethodHandle(State.class, MapType.class).bindTo(mapType); return new ScalarFunctionImplementation( diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/MapElementAtFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/MapElementAtFunction.java index fe56e50223394..3176e0e21d33d 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/MapElementAtFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/MapElementAtFunction.java @@ -29,6 +29,7 @@ import io.prestosql.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static io.prestosql.metadata.FunctionKind.SCALAR; import static io.prestosql.metadata.Signature.typeVariable; @@ -76,7 +77,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - MethodHandle keyEqualsMethod = metadata.getScalarFunctionImplementation(metadata.resolveOperator(EQUAL, ImmutableList.of(keyType, keyType))).getMethodHandle(); + MethodHandle keyEqualsMethod = metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(keyType, keyType)), Optional.empty()).getMethodHandle(); MethodHandle methodHandle; if (keyType.getJavaType() == boolean.class) { diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/MapHashCodeOperator.java b/presto-main/src/main/java/io/prestosql/operator/scalar/MapHashCodeOperator.java index 09930d08bfca7..55994cae215fd 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/MapHashCodeOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/MapHashCodeOperator.java @@ -23,6 +23,7 @@ import io.prestosql.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static io.prestosql.metadata.Signature.comparableTypeParameter; import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; @@ -55,8 +56,8 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - MethodHandle keyHashCodeFunction = metadata.getScalarFunctionImplementation(metadata.resolveOperator(HASH_CODE, ImmutableList.of(keyType))).getMethodHandle(); - MethodHandle valueHashCodeFunction = metadata.getScalarFunctionImplementation(metadata.resolveOperator(HASH_CODE, ImmutableList.of(valueType))).getMethodHandle(); + MethodHandle keyHashCodeFunction = metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(keyType)), Optional.empty()).getMethodHandle(); + MethodHandle valueHashCodeFunction = metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(valueType)), Optional.empty()).getMethodHandle(); MethodHandle method = METHOD_HANDLE.bindTo(keyHashCodeFunction).bindTo(valueHashCodeFunction).bindTo(keyType).bindTo(valueType); return new ScalarFunctionImplementation( diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/MapToMapCast.java b/presto-main/src/main/java/io/prestosql/operator/scalar/MapToMapCast.java index 96b57469662e5..9d80120bfea9f 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/MapToMapCast.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/MapToMapCast.java @@ -18,6 +18,7 @@ import io.airlift.slice.Slice; import io.prestosql.annotation.UsedByGeneratedCode; import io.prestosql.metadata.BoundVariables; +import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.SqlOperator; @@ -32,6 +33,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -114,8 +116,8 @@ private MethodHandle buildProcessor(Metadata metadata, Type fromType, Type toTyp // Adapt cast that takes ([ConnectorSession,] ?) to one that takes (?, ConnectorSession), where ? is the return type of getter. ResolvedFunction resolvedFunction = metadata.getCoercion(fromType, toType); - ScalarFunctionImplementation castImplementation = metadata.getScalarFunctionImplementation(resolvedFunction); - MethodHandle cast = castImplementation.getMethodHandle(); + FunctionMetadata functionMetadata = metadata.getFunctionMetadata(resolvedFunction); + MethodHandle cast = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle(); if (cast.type().parameterArray()[0] != ConnectorSession.class) { cast = MethodHandles.dropArguments(cast, 0, ConnectorSession.class); } @@ -123,7 +125,7 @@ private MethodHandle buildProcessor(Metadata metadata, Type fromType, Type toTyp MethodHandle target = compose(cast, getter); // If the key cast function is nullable, check the result is not null. - if (isKey && castImplementation.isNullable()) { + if (isKey && functionMetadata.isNullable()) { target = compose(nullChecker(target.type().returnType()), target); } diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/RowComparisonOperator.java b/presto-main/src/main/java/io/prestosql/operator/scalar/RowComparisonOperator.java index ac74d2cb0718f..c93763ec2d547 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/RowComparisonOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/RowComparisonOperator.java @@ -26,6 +26,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static io.prestosql.metadata.Signature.orderableWithVariadicBound; import static io.prestosql.spi.type.BooleanType.BOOLEAN; @@ -51,7 +52,7 @@ protected List getMethodHandles(RowType type, Metadata metadata, O ImmutableList.Builder argumentMethods = ImmutableList.builder(); for (Type parameterType : type.getTypeParameters()) { ResolvedFunction resolvedFunction = metadata.resolveOperator(operatorType, ImmutableList.of(parameterType, parameterType)); - argumentMethods.add(metadata.getScalarFunctionImplementation(resolvedFunction).getMethodHandle()); + argumentMethods.add(metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle()); } return argumentMethods.build(); } diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/RowEqualOperator.java b/presto-main/src/main/java/io/prestosql/operator/scalar/RowEqualOperator.java index 5252c85fcedc6..b772de9446fcd 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/RowEqualOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/RowEqualOperator.java @@ -25,6 +25,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.comparableWithVariadicBound; @@ -76,8 +77,7 @@ public static List resolveFieldEqualOperators(RowType rowType, Met private static MethodHandle resolveEqualOperator(Type type, Metadata metadata) { ResolvedFunction operator = metadata.resolveOperator(EQUAL, ImmutableList.of(type, type)); - ScalarFunctionImplementation implementation = metadata.getScalarFunctionImplementation(operator); - return implementation.getMethodHandle(); + return metadata.getScalarFunctionInvoker(operator, Optional.empty()).getMethodHandle(); } public static Boolean equals(RowType rowType, List fieldEqualOperators, Block leftRow, Block rightRow) diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/ScalarFunctionImplementation.java b/presto-main/src/main/java/io/prestosql/operator/scalar/ScalarFunctionImplementation.java index 085f25e899957..8cb4f010577d2 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/ScalarFunctionImplementation.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/ScalarFunctionImplementation.java @@ -68,11 +68,6 @@ public ScalarFunctionImplementation(List choices) this.choices = ImmutableList.copyOf(choices); } - public boolean isNullable() - { - return choices.get(0).isNullable(); - } - public ArgumentProperty getArgumentProperty(int argumentIndex) { return getArgumentProperties().get(argumentIndex); diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/TryCastFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/TryCastFunction.java index b7ed9a0b371a6..83ba5fb5d86e9 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/TryCastFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/TryCastFunction.java @@ -28,10 +28,14 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static io.prestosql.metadata.FunctionKind.SCALAR; import static io.prestosql.metadata.Signature.castableToTypeParameter; import static io.prestosql.metadata.Signature.typeVariable; +import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; +import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; +import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_BOXED_TYPE; import static java.lang.invoke.MethodHandles.catchException; import static java.lang.invoke.MethodHandles.constant; import static java.lang.invoke.MethodHandles.dropArguments; @@ -67,19 +71,17 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type toType = boundVariables.getTypeVariable("T"); Class returnType = Primitives.wrap(toType.getJavaType()); - List argumentProperties; - MethodHandle tryCastHandle; // the resulting method needs to return a boxed type ResolvedFunction resolvedFunction = metadata.getCoercion(fromType, toType); - ScalarFunctionImplementation implementation = metadata.getScalarFunctionImplementation(resolvedFunction); - argumentProperties = ImmutableList.of(implementation.getArgumentProperty(0)); - MethodHandle coercion = implementation.getMethodHandle(); + MethodHandle coercion = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle(); coercion = coercion.asType(methodType(returnType, coercion.type())); MethodHandle exceptionHandler = dropArguments(constant(returnType, null), 0, RuntimeException.class); - tryCastHandle = catchException(coercion, RuntimeException.class, exceptionHandler); + MethodHandle tryCastHandle = catchException(coercion, RuntimeException.class, exceptionHandler); + boolean nullable = metadata.getFunctionMetadata(resolvedFunction).getArgumentDefinitions().get(0).isNullable(); + List argumentProperties = ImmutableList.of(nullable ? valueTypeArgumentProperty(USE_BOXED_TYPE) : valueTypeArgumentProperty(RETURN_NULL_ON_NULL)); return new ScalarFunctionImplementation(true, argumentProperties, tryCastHandle); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/InterpretedFunctionInvoker.java b/presto-main/src/main/java/io/prestosql/sql/InterpretedFunctionInvoker.java index 6fc2409910c49..e1aa918967f44 100644 --- a/presto-main/src/main/java/io/prestosql/sql/InterpretedFunctionInvoker.java +++ b/presto-main/src/main/java/io/prestosql/sql/InterpretedFunctionInvoker.java @@ -13,21 +13,28 @@ */ package io.prestosql.sql; -import com.google.common.base.Defaults; +import com.google.common.collect.ImmutableList; +import io.prestosql.metadata.FunctionInvoker; +import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.ResolvedFunction; -import io.prestosql.operator.scalar.ScalarFunctionImplementation; -import io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.function.InvocationConvention; +import io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention; +import io.prestosql.type.FunctionType; import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Optional; import static com.google.common.base.Throwables.throwIfUnchecked; -import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentType.VALUE_TYPE; -import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.USE_NULL_FLAG; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static java.lang.invoke.MethodHandleProxies.asInterfaceInstance; import static java.util.Objects.requireNonNull; @@ -52,36 +59,41 @@ public Object invoke(ResolvedFunction function, ConnectorSession session, Object */ public Object invoke(ResolvedFunction function, ConnectorSession session, List arguments) { - ScalarFunctionImplementation implementation = metadata.getScalarFunctionImplementation(function); - MethodHandle method = implementation.getMethodHandle(); + FunctionMetadata functionMetadata = metadata.getFunctionMetadata(function); + FunctionInvoker invoker = metadata.getScalarFunctionInvoker(function, Optional.of(getInvocationConvention(function, functionMetadata))); + MethodHandle method = invoker.getMethodHandle(); + + List actualArguments = new ArrayList<>(); // handle function on instance method, to allow use of fields - method = bindInstanceFactory(method, implementation); + if (invoker.getInstanceFactory().isPresent()) { + try { + actualArguments.add(invoker.getInstanceFactory().get().invoke()); + } + catch (Throwable throwable) { + throw propagate(throwable); + } + } - if (method.type().parameterCount() > 0 && method.type().parameterType(0) == ConnectorSession.class) { - method = method.bindTo(session); + // add session + if (method.type().parameterCount() > actualArguments.size() && method.type().parameterType(actualArguments.size()) == ConnectorSession.class) { + actualArguments.add(session); } - List actualArguments = new ArrayList<>(); + for (int i = 0; i < arguments.size(); i++) { Object argument = arguments.get(i); - ArgumentProperty argumentProperty = implementation.getArgumentProperty(i); - if (argumentProperty.getArgumentType() == VALUE_TYPE) { - if (implementation.getArgumentProperty(i).getNullConvention() == USE_NULL_FLAG) { - boolean isNull = argument == null; - if (isNull) { - argument = Defaults.defaultValue(method.type().parameterType(actualArguments.size())); - } - actualArguments.add(argument); - actualArguments.add(isNull); - } - else { - actualArguments.add(argument); - } + + // if argument is null and function does not handle nulls, result is null + if (argument == null && !functionMetadata.getArgumentDefinitions().get(i).isNullable()) { + return null; } - else { - argument = asInterfaceInstance(method.type().parameterType(i), (MethodHandle) argument); - actualArguments.add(argument); + + Optional> lambdaInterface = invoker.getLambdaInterfaces().get(i); + if (lambdaInterface.isPresent()) { + argument = asInterfaceInstance(lambdaInterface.get(), (MethodHandle) argument); } + + actualArguments.add(argument); } try { @@ -92,18 +104,26 @@ public Object invoke(ResolvedFunction function, ConnectorSession session, List argumentConventions = ImmutableList.builder(); + for (int i = 0; i < functionMetadata.getArgumentDefinitions().size(); i++) { + if (function.getSignature().getArgumentTypes().get(i).getBase().equalsIgnoreCase(FunctionType.NAME)) { + argumentConventions.add(FUNCTION); + } + else if (functionMetadata.getArgumentDefinitions().get(i).isNullable()) { + argumentConventions.add(BOXED_NULLABLE); + } + else { + argumentConventions.add(NEVER_NULL); + } } - try { - return method.bindTo(implementation.getInstanceFactory().get().invoke()); - } - catch (Throwable throwable) { - throw propagate(throwable); - } + return new InvocationConvention( + argumentConventions.build(), + functionMetadata.isNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, + true, + true); } private static RuntimeException propagate(Throwable throwable) diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/InCodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/InCodeGenerator.java index 3d08d714400e2..7e87ae36eec99 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/InCodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/InCodeGenerator.java @@ -26,7 +26,7 @@ import io.airlift.bytecode.instruction.LabelNode; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.ResolvedFunction; -import io.prestosql.operator.scalar.ScalarFunctionImplementation; +import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.function.OperatorType; import io.prestosql.spi.type.BigintType; import io.prestosql.spi.type.DateType; @@ -40,6 +40,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -48,6 +49,8 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; import static io.airlift.bytecode.instruction.JumpInstruction.jump; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.prestosql.spi.function.OperatorType.HASH_CODE; import static io.prestosql.spi.function.OperatorType.INDETERMINATE; import static io.prestosql.sql.gen.BytecodeUtils.ifWasNullPopAndGoto; @@ -121,9 +124,10 @@ public BytecodeNode generateExpression(ResolvedFunction resolvedFunction, Byteco Metadata metadata = generatorContext.getMetadata(); ResolvedFunction resolvedHashCodeFunction = metadata.resolveOperator(HASH_CODE, ImmutableList.of(type)); - MethodHandle hashCodeFunction = metadata.getScalarFunctionImplementation(resolvedHashCodeFunction).getMethodHandle(); + MethodHandle hashCodeFunction = metadata.getScalarFunctionInvoker(resolvedHashCodeFunction, Optional.empty()).getMethodHandle(); ResolvedFunction resolvedIsIndeterminate = metadata.resolveOperator(INDETERMINATE, ImmutableList.of(type)); - ScalarFunctionImplementation isIndeterminateFunction = metadata.getScalarFunctionImplementation(resolvedIsIndeterminate); + InvocationConvention indeterminateCallingConvention = new InvocationConvention(ImmutableList.of(NULL_FLAG), FAIL_ON_NULL, false, false); + MethodHandle isIndeterminateFunction = metadata.getScalarFunctionInvoker(resolvedIsIndeterminate, Optional.of(indeterminateCallingConvention)).getMethodHandle(); ImmutableListMultimap.Builder hashBucketsBuilder = ImmutableListMultimap.builder(); ImmutableList.Builder defaultBucket = ImmutableList.builder(); @@ -132,7 +136,7 @@ public BytecodeNode generateExpression(ResolvedFunction resolvedFunction, Byteco for (RowExpression testValue : values) { BytecodeNode testBytecode = generatorContext.generate(testValue); - if (isDeterminateConstant(testValue, isIndeterminateFunction.getMethodHandle())) { + if (isDeterminateConstant(testValue, isIndeterminateFunction)) { ConstantExpression constant = (ConstantExpression) testValue; Object object = constant.getValue(); switch (switchGenerationCase) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java index cd04d1c794960..99d130b3dd421 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionInterpreter.java @@ -678,7 +678,7 @@ protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object con return value; case MINUS: ResolvedFunction resolvedOperator = metadata.resolveOperator(OperatorType.NEGATION, types(node.getValue())); - MethodHandle handle = metadata.getScalarFunctionImplementation(resolvedOperator).getMethodHandle(); + MethodHandle handle = metadata.getScalarFunctionInvoker(resolvedOperator, Optional.empty()).getMethodHandle(); if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) { handle = handle.bindTo(session); diff --git a/presto-main/src/main/java/io/prestosql/sql/relational/optimizer/ExpressionOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/relational/optimizer/ExpressionOptimizer.java index d859c2772f066..c89ec34d43bb0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/relational/optimizer/ExpressionOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/relational/optimizer/ExpressionOptimizer.java @@ -14,11 +14,13 @@ package io.prestosql.sql.relational.optimizer; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Iterables; import io.prestosql.Session; import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.ResolvedFunction; import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.function.InvocationConvention; +import io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention; import io.prestosql.spi.type.ArrayType; import io.prestosql.spi.type.MapType; import io.prestosql.spi.type.RowType; @@ -32,19 +34,25 @@ import io.prestosql.sql.relational.SpecialForm; import io.prestosql.sql.relational.VariableReferenceExpression; import io.prestosql.sql.tree.QualifiedName; +import io.prestosql.type.FunctionType; import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Predicates.instanceOf; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.Signature.mangleOperatorName; import static io.prestosql.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME; import static io.prestosql.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME; import static io.prestosql.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION; +import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.prestosql.spi.function.OperatorType.CAST; import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.VarcharType.VARCHAR; @@ -98,15 +106,16 @@ public RowExpression visitCall(CallExpression call, Void context) // TODO: optimize function calls with lambda arguments. For example, apply(x -> x + 2, 1) FunctionMetadata functionMetadata = metadata.getFunctionMetadata(call.getResolvedFunction()); - if (Iterables.all(arguments, instanceOf(ConstantExpression.class)) && functionMetadata.isDeterministic()) { - MethodHandle method = metadata.getScalarFunctionImplementation(call.getResolvedFunction()).getMethodHandle(); + if (arguments.stream().allMatch(ConstantExpression.class::isInstance) && functionMetadata.isDeterministic()) { + InvocationConvention convention = getInvocationConvention(call.getResolvedFunction(), functionMetadata); + MethodHandle method = metadata.getScalarFunctionInvoker(call.getResolvedFunction(), Optional.of(convention)).getMethodHandle(); + List constantArguments = new ArrayList<>(); if (method.type().parameterCount() > 0 && method.type().parameterType(0) == ConnectorSession.class) { - method = method.bindTo(session); + constantArguments.add(session); } int index = 0; - List constantArguments = new ArrayList<>(); for (RowExpression argument : arguments) { Object value = ((ConstantExpression) argument).getValue(); // if any argument is null, return null @@ -131,6 +140,28 @@ public RowExpression visitCall(CallExpression call, Void context) return call(call.getResolvedFunction(), metadata.getType(call.getResolvedFunction().getSignature().getReturnType()), arguments); } + private InvocationConvention getInvocationConvention(ResolvedFunction function, FunctionMetadata functionMetadata) + { + ImmutableList.Builder argumentConventions = ImmutableList.builder(); + for (int i = 0; i < functionMetadata.getArgumentDefinitions().size(); i++) { + if (function.getSignature().getArgumentTypes().get(i).getBase().equalsIgnoreCase(FunctionType.NAME)) { + argumentConventions.add(FUNCTION); + } + else if (functionMetadata.getArgumentDefinitions().get(i).isNullable()) { + argumentConventions.add(BOXED_NULLABLE); + } + else { + argumentConventions.add(NEVER_NULL); + } + } + + return new InvocationConvention( + argumentConventions.build(), + functionMetadata.isNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, + true, + true); + } + @Override public RowExpression visitSpecialForm(SpecialForm specialForm, Void context) { diff --git a/presto-main/src/main/java/io/prestosql/type/InternalTypeManager.java b/presto-main/src/main/java/io/prestosql/type/InternalTypeManager.java index 844a408ae41b9..8ace80de37948 100644 --- a/presto-main/src/main/java/io/prestosql/type/InternalTypeManager.java +++ b/presto-main/src/main/java/io/prestosql/type/InternalTypeManager.java @@ -14,7 +14,6 @@ package io.prestosql.type; import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.ResolvedFunction; import io.prestosql.spi.function.OperatorType; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeId; @@ -25,6 +24,7 @@ import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -60,7 +60,6 @@ public Type getType(TypeId id) @Override public MethodHandle resolveOperator(OperatorType operatorType, List argumentTypes) { - ResolvedFunction signature = metadata.resolveOperator(operatorType, argumentTypes); - return metadata.getScalarFunctionImplementation(signature).getMethodHandle(); + return metadata.getScalarFunctionInvoker(metadata.resolveOperator(operatorType, argumentTypes), Optional.empty()).getMethodHandle(); } } diff --git a/presto-main/src/main/java/io/prestosql/util/FastutilSetHelper.java b/presto-main/src/main/java/io/prestosql/util/FastutilSetHelper.java index c20be518c5e7e..f40a3f2f29970 100644 --- a/presto-main/src/main/java/io/prestosql/util/FastutilSetHelper.java +++ b/presto-main/src/main/java/io/prestosql/util/FastutilSetHelper.java @@ -28,6 +28,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodType; import java.util.Collection; +import java.util.Optional; import java.util.Set; import static com.google.common.base.Throwables.throwIfInstanceOf; @@ -93,8 +94,8 @@ private static final class LongStrategy private LongStrategy(Metadata metadata, Type type) { - hashCodeHandle = metadata.getScalarFunctionImplementation(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type))).getMethodHandle(); - equalsHandle = metadata.getScalarFunctionImplementation(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type))).getMethodHandle(); + hashCodeHandle = metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type)), Optional.empty()).getMethodHandle(); + equalsHandle = metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type)), Optional.empty()).getMethodHandle(); } @Override @@ -135,8 +136,8 @@ private static final class DoubleStrategy private DoubleStrategy(Metadata metadata, Type type) { - hashCodeHandle = metadata.getScalarFunctionImplementation(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type))).getMethodHandle(); - equalsHandle = metadata.getScalarFunctionImplementation(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type))).getMethodHandle(); + hashCodeHandle = metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type)), Optional.empty()).getMethodHandle(); + equalsHandle = metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type)), Optional.empty()).getMethodHandle(); } @Override @@ -177,10 +178,10 @@ private static final class ObjectStrategy private ObjectStrategy(Metadata metadata, Type type) { - hashCodeHandle = metadata.getScalarFunctionImplementation(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type))) + hashCodeHandle = metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type)), Optional.empty()) .getMethodHandle() .asType(MethodType.methodType(long.class, Object.class)); - equalsHandle = metadata.getScalarFunctionImplementation(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type))) + equalsHandle = metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type)), Optional.empty()) .getMethodHandle() .asType(MethodType.methodType(Boolean.class, Object.class, Object.class)); } diff --git a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java index 52e8f6e822137..2f958c29001bc 100644 --- a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java +++ b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java @@ -19,7 +19,6 @@ import io.prestosql.Session; import io.prestosql.connector.CatalogName; import io.prestosql.operator.aggregation.InternalAggregationFunction; -import io.prestosql.operator.scalar.ScalarFunctionImplementation; import io.prestosql.operator.window.WindowFunctionSupplier; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.BlockEncoding; @@ -644,12 +643,6 @@ public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunctio throw new UnsupportedOperationException(); } - @Override - public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction) - { - throw new UnsupportedOperationException(); - } - @Override public ProcedureRegistry getProcedureRegistry() { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java index 564795fc7d1da..b7b9bd5637ee5 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java @@ -23,19 +23,20 @@ import io.prestosql.Session; import io.prestosql.connector.CatalogName; import io.prestosql.metadata.AbstractMockMetadata; +import io.prestosql.metadata.FunctionInvoker; import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.Signature; import io.prestosql.metadata.TableHandle; import io.prestosql.metadata.TableProperties; -import io.prestosql.operator.scalar.ScalarFunctionImplementation; import io.prestosql.security.AllowAllAccessControl; import io.prestosql.spi.block.BlockEncodingSerde; import io.prestosql.spi.block.SortOrder; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorTableHandle; import io.prestosql.spi.connector.ConnectorTableProperties; +import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.type.Type; @@ -174,9 +175,9 @@ public ResolvedFunction getCoercion(Type fromType, Type toType) } @Override - public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction) + public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional invocationConvention) { - return delegate.getScalarFunctionImplementation(resolvedFunction); + return delegate.getScalarFunctionInvoker(resolvedFunction, invocationConvention); } @Override