Skip to content

Commit

Permalink
Remove unnecessary calls to resolve operator
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Oct 7, 2020
1 parent cb3662c commit 46a9c57
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import io.airlift.units.DataSize;
import io.prestosql.PagesIndexPageSorter;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.operator.PagesIndex;
import io.prestosql.plugin.hive.authentication.NoHdfsAuthentication;
import io.prestosql.plugin.hive.azure.HiveAzureConfig;
Expand All @@ -43,7 +42,6 @@
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.NamedTypeSignature;
Expand All @@ -62,15 +60,14 @@
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
import static io.prestosql.metadata.MetadataManager.createTestMetadataManager;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.OperatorType.IS_DISTINCT_FROM;
import static io.prestosql.spi.function.InvocationConvention.simpleConvention;
import static io.prestosql.spi.type.Decimals.encodeScaledValue;

public final class HiveTestUtils
Expand Down Expand Up @@ -233,9 +230,7 @@ public static Slice longDecimal(String value)

public static MethodHandle distinctFromOperator(Type type)
{
ResolvedFunction function = METADATA.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(type, type));
InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NULL_FLAG, NULL_FLAG), FAIL_ON_NULL, false, false);
return METADATA.getScalarFunctionInvoker(function, Optional.of(invocationConvention)).getMethodHandle();
return TYPE_MANAGER.getTypeOperators().getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, NULL_FLAG, NULL_FLAG));
}

public static boolean isDistinctFrom(MethodHandle handle, Block left, Block right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.spi.connector.RecordCursor;
import io.prestosql.spi.connector.RecordSet;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeOperators;

import java.lang.invoke.MethodHandle;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static io.prestosql.spi.function.OperatorType.EQUAL;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.prestosql.spi.function.InvocationConvention.simpleConvention;
import static java.lang.Boolean.TRUE;
import static java.util.Objects.requireNonNull;

Expand All @@ -43,18 +43,16 @@ public class FieldSetFilteringRecordSet
private final RecordSet delegate;
private final List<Set<Field>> fieldSets;

public FieldSetFilteringRecordSet(Metadata metadata, RecordSet delegate, List<Set<Integer>> fieldSets)
public FieldSetFilteringRecordSet(TypeOperators typeOperators, RecordSet delegate, List<Set<Integer>> fieldSets)
{
requireNonNull(metadata, "metadata is null");
this.delegate = requireNonNull(delegate, "delegate is null");

ImmutableList.Builder<Set<Field>> fieldSetsBuilder = ImmutableList.builder();
List<Type> columnTypes = delegate.getColumnTypes();
for (Set<Integer> fieldSet : requireNonNull(fieldSets, "fieldSets is null")) {
ImmutableSet.Builder<Field> fieldSetBuilder = ImmutableSet.builder();
for (int field : fieldSet) {
ResolvedFunction resolvedFunction = metadata.resolveOperator(EQUAL, ImmutableList.of(columnTypes.get(field), columnTypes.get(field)));
MethodHandle methodHandle = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();
MethodHandle methodHandle = typeOperators.getEqualOperator(columnTypes.get(field), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL));
fieldSetBuilder.add(new Field(field, methodHandle));
}
fieldSetsBuilder.add(fieldSetBuilder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1173,10 +1173,7 @@ protected Type visitBetweenPredicate(BetweenPredicate node, StackableAstVisitorC
semanticException(TYPE_MISMATCH, node, "Cannot check if %s is BETWEEN %s and %s", valueType, minType, maxType);
}

try {
metadata.resolveOperator(OperatorType.LESS_THAN_OR_EQUAL, List.of(commonType.get(), commonType.get()));
}
catch (OperatorNotFoundException e) {
if (!commonType.get().isOrderable()) {
semanticException(TYPE_MISMATCH, node, "Cannot check if %s is BETWEEN %s and %s", valueType, minType, maxType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@
import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic;
import static io.airlift.bytecode.instruction.JumpInstruction.jump;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.prestosql.spi.function.InvocationConvention.simpleConvention;
import static io.prestosql.spi.function.OperatorType.EQUAL;
import static io.prestosql.spi.function.OperatorType.HASH_CODE;
import static io.prestosql.spi.function.OperatorType.INDETERMINATE;
Expand Down Expand Up @@ -129,8 +132,8 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext

SwitchGenerationCase switchGenerationCase = checkSwitchGenerationCase(type, testExpressions);

FunctionInvoker equalsInvoker = generatorContext.getScalarFunctionInvoker(resolvedEqualsFunction, Optional.empty());
FunctionInvoker hashCodeInvoker = generatorContext.getScalarFunctionInvoker(resolvedHashCodeFunction, Optional.empty());
MethodHandle equalsMethodHandle = generatorContext.getScalarFunctionInvoker(resolvedEqualsFunction, Optional.of(simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL))).getMethodHandle();
MethodHandle hashCodeMethodHandle = generatorContext.getScalarFunctionInvoker(resolvedHashCodeFunction, Optional.of(simpleConvention(FAIL_ON_NULL, NEVER_NULL))).getMethodHandle();
InvocationConvention indeterminateCallingConvention = new InvocationConvention(ImmutableList.of(NULL_FLAG), FAIL_ON_NULL, false, false);
FunctionInvoker indeterminateInvoker = generatorContext.getScalarFunctionInvoker(resolvedIsIndeterminate, Optional.of(indeterminateCallingConvention));

Expand All @@ -151,7 +154,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext
break;
case HASH_SWITCH:
try {
int hashCode = Long.hashCode((Long) hashCodeInvoker.getMethodHandle().invoke(object));
int hashCode = Long.hashCode((Long) hashCodeMethodHandle.invoke(object));
hashBucketsBuilder.put(hashCode, testBytecode);
}
catch (Throwable throwable) {
Expand Down Expand Up @@ -217,7 +220,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext
switchBuilder.defaultCase(jump(defaultLabel));
Binding hashCodeBinding = generatorContext
.getCallSiteBinder()
.bind(hashCodeInvoker.getMethodHandle());
.bind(hashCodeMethodHandle);
switchBlock = new BytecodeBlock()
.comment("lookupSwitch(hashCode(<stackValue>))")
.getVariable(value)
Expand All @@ -227,7 +230,7 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext
.append(switchBuilder.build());
break;
case SET_CONTAINS:
Set<?> constantValuesSet = toFastutilHashSet(constantValues, type, hashCodeInvoker, equalsInvoker);
Set<?> constantValuesSet = toFastutilHashSet(constantValues, type, hashCodeMethodHandle, equalsMethodHandle);
Binding constant = generatorContext.getCallSiteBinder().bind(constantValuesSet, constantValuesSet.getClass());

switchBlock = new BytecodeBlock()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@
import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.prestosql.spi.StandardErrorCode.TYPE_MISMATCH;
import static io.prestosql.spi.function.OperatorType.EQUAL;
import static io.prestosql.spi.function.OperatorType.HASH_CODE;
import static io.prestosql.spi.type.TypeUtils.readNativeValue;
import static io.prestosql.spi.type.TypeUtils.writeNativeValue;
import static io.prestosql.spi.type.VarcharType.VARCHAR;
Expand Down Expand Up @@ -575,7 +577,12 @@ protected Object visitInPredicate(InPredicate node, Object context)
if (valueList.getValues().stream().allMatch(Literal.class::isInstance) &&
valueList.getValues().stream().noneMatch(NullLiteral.class::isInstance)) {
Set<Object> objectSet = valueList.getValues().stream().map(expression -> process(expression, context)).collect(Collectors.toSet());
set = FastutilSetHelper.toFastutilHashSet(objectSet, type(node.getValue()), metadata);
Type type = type(node.getValue());
set = FastutilSetHelper.toFastutilHashSet(
objectSet,
type,
metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type)), Optional.empty()).getMethodHandle(),
metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type)), Optional.empty()).getMethodHandle());
}
inListCache.put(valueList, set);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.predicate.NullableValue;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeOperators;
import io.prestosql.spiller.PartitioningSpillerFactory;
import io.prestosql.spiller.SingleStreamSpillerFactory;
import io.prestosql.spiller.SpillerFactory;
Expand Down Expand Up @@ -322,6 +323,7 @@ public class LocalExecutionPlanner
private final LookupJoinOperators lookupJoinOperators;
private final OrderingCompiler orderingCompiler;
private final DynamicFilterConfig dynamicFilterConfig;
private final TypeOperators typeOperators;

@Inject
public LocalExecutionPlanner(
Expand All @@ -345,7 +347,8 @@ public LocalExecutionPlanner(
JoinCompiler joinCompiler,
LookupJoinOperators lookupJoinOperators,
OrderingCompiler orderingCompiler,
DynamicFilterConfig dynamicFilterConfig)
DynamicFilterConfig dynamicFilterConfig,
TypeOperators typeOperators)
{
this.explainAnalyzeContext = requireNonNull(explainAnalyzeContext, "explainAnalyzeContext is null");
this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null");
Expand All @@ -371,6 +374,7 @@ public LocalExecutionPlanner(
this.lookupJoinOperators = requireNonNull(lookupJoinOperators, "lookupJoinOperators is null");
this.orderingCompiler = requireNonNull(orderingCompiler, "orderingCompiler is null");
this.dynamicFilterConfig = requireNonNull(dynamicFilterConfig, "dynamicFilterConfig is null");
this.typeOperators = requireNonNull(typeOperators, "typeOperators is null");
}

public LocalExecutionPlan plan(
Expand Down Expand Up @@ -1521,7 +1525,7 @@ public PhysicalOperation visitIndexSource(IndexSourceNode node, LocalExecutionPl
List<Integer> remappedProbeKeyChannels = remappedProbeKeyChannelsBuilder.build();
Function<RecordSet, RecordSet> probeKeyNormalizer = recordSet -> {
if (!overlappingFieldSets.isEmpty()) {
recordSet = new FieldSetFilteringRecordSet(metadata, recordSet, overlappingFieldSets);
recordSet = new FieldSetFilteringRecordSet(typeOperators, recordSet, overlappingFieldSets);
}
return new MappedRecordSet(recordSet, remappedProbeKeyChannels);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,8 @@ private List<Driver> createDrivers(Session session, Plan plan, OutputFactory out
joinCompiler,
new LookupJoinOperators(),
new OrderingCompiler(),
new DynamicFilterConfig());
new DynamicFilterConfig(),
typeOperators);

// plan query
StageExecutionDescriptor stageExecutionDescriptor = subplan.getFragment().getStageExecutionDescriptor();
Expand Down
48 changes: 14 additions & 34 deletions presto-main/src/main/java/io/prestosql/util/FastutilSetHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
*/
package io.prestosql.util;

import com.google.common.collect.ImmutableList;
import io.prestosql.metadata.FunctionInvoker;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.type.Type;
import it.unimi.dsi.fastutil.Hash;
Expand All @@ -27,58 +24,41 @@
import it.unimi.dsi.fastutil.objects.ObjectOpenCustomHashSet;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.Collection;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.base.Verify.verifyNotNull;
import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.prestosql.spi.function.OperatorType.EQUAL;
import static io.prestosql.spi.function.OperatorType.HASH_CODE;
import static java.lang.Boolean.TRUE;
import static java.lang.invoke.MethodType.methodType;
import static java.util.Objects.requireNonNull;

public final class FastutilSetHelper
{
private FastutilSetHelper() {}

public static Set<?> toFastutilHashSet(Set<?> set, Type type, Metadata metadata)
{
return toFastutilHashSet(
set,
type,
metadata.getScalarFunctionInvoker(metadata.resolveOperator(HASH_CODE, ImmutableList.of(type)), Optional.empty()),
metadata.getScalarFunctionInvoker(metadata.resolveOperator(EQUAL, ImmutableList.of(type, type)), Optional.empty()));
}

@SuppressWarnings("unchecked")
public static Set<?> toFastutilHashSet(Set<?> set, Type type, FunctionInvoker hashCode, FunctionInvoker equals)
public static Set<?> toFastutilHashSet(Set<?> set, Type type, MethodHandle hashCodeHandle, MethodHandle equalsHandle)
{
requireNonNull(set, "set is null");
requireNonNull(type, "type is null");
requireNonNull(hashCode, "hashCode is null");
checkArgument(hashCode.getInstanceFactory().isEmpty(), "hashCode method has instance factory");
requireNonNull(equals, "equals is null");
checkArgument(equals.getInstanceFactory().isEmpty(), "equals method has instance factory");

// 0.25 as the load factor is chosen because the argument set is assumed to be small (<10000),
// and the return set is assumed to be read-heavy.
// The performance of InCodeGenerator heavily depends on the load factor being small.
Class<?> javaElementType = type.getJavaType();
if (javaElementType == long.class) {
return new LongOpenCustomHashSet((Collection<Long>) set, 0.25f, new LongStrategy(hashCode, equals));
return new LongOpenCustomHashSet((Collection<Long>) set, 0.25f, new LongStrategy(hashCodeHandle, equalsHandle));
}
if (javaElementType == double.class) {
return new DoubleOpenCustomHashSet((Collection<Double>) set, 0.25f, new DoubleStrategy(hashCode, equals));
return new DoubleOpenCustomHashSet((Collection<Double>) set, 0.25f, new DoubleStrategy(hashCodeHandle, equalsHandle));
}
if (javaElementType == boolean.class) {
return new BooleanOpenHashSet((Collection<Boolean>) set, 0.25f);
}
else if (!type.getJavaType().isPrimitive()) {
return new ObjectOpenCustomHashSet<>(set, 0.25f, new ObjectStrategy(hashCode, equals));
return new ObjectOpenCustomHashSet<>(set, 0.25f, new ObjectStrategy(hashCodeHandle, equalsHandle));
}
else {
throw new UnsupportedOperationException("Unsupported native type in set: " + type.getJavaType() + " with type " + type.getTypeSignature());
Expand Down Expand Up @@ -111,10 +91,10 @@ private static final class LongStrategy
private final MethodHandle hashCodeHandle;
private final MethodHandle equalsHandle;

private LongStrategy(FunctionInvoker hashCode, FunctionInvoker equals)
public LongStrategy(MethodHandle hashCodeHandle, MethodHandle equalsHandle)
{
hashCodeHandle = hashCode.getMethodHandle();
equalsHandle = equals.getMethodHandle();
this.hashCodeHandle = requireNonNull(hashCodeHandle, "hashCodeHandle is null");
this.equalsHandle = requireNonNull(equalsHandle, "equalsHandle is null");
}

@Override
Expand Down Expand Up @@ -153,10 +133,10 @@ private static final class DoubleStrategy
private final MethodHandle hashCodeHandle;
private final MethodHandle equalsHandle;

private DoubleStrategy(FunctionInvoker hashCode, FunctionInvoker equals)
public DoubleStrategy(MethodHandle hashCodeHandle, MethodHandle equalsHandle)
{
hashCodeHandle = hashCode.getMethodHandle();
equalsHandle = equals.getMethodHandle();
this.hashCodeHandle = requireNonNull(hashCodeHandle, "hashCodeHandle is null");
this.equalsHandle = requireNonNull(equalsHandle, "equalsHandle is null");
}

@Override
Expand Down Expand Up @@ -195,10 +175,10 @@ private static final class ObjectStrategy
private final MethodHandle hashCodeHandle;
private final MethodHandle equalsHandle;

private ObjectStrategy(FunctionInvoker hashCode, FunctionInvoker equals)
public ObjectStrategy(MethodHandle hashCodeHandle, MethodHandle equalsHandle)
{
hashCodeHandle = hashCode.getMethodHandle().asType(MethodType.methodType(long.class, Object.class));
equalsHandle = equals.getMethodHandle().asType(MethodType.methodType(Boolean.class, Object.class, Object.class));
this.hashCodeHandle = requireNonNull(hashCodeHandle, "hashCodeHandle is null").asType(methodType(long.class, Object.class));
this.equalsHandle = requireNonNull(equalsHandle, "equalsHandle is null").asType(methodType(Boolean.class, Object.class, Object.class));
}

@Override
Expand Down
Loading

0 comments on commit 46a9c57

Please sign in to comment.