Skip to content

Commit

Permalink
Add FunctionCallBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Jun 23, 2019
1 parent 0a0689e commit 9f78af6
Show file tree
Hide file tree
Showing 55 changed files with 821 additions and 309 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import io.prestosql.split.SplitSource;
import io.prestosql.sql.gen.PageFunctionCompiler;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.HashGenerationOptimizer;
Expand Down Expand Up @@ -215,21 +216,24 @@ private static List<Split> getNextBatch(SplitSource splitSource)

protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNodeId planNodeId, List<Type> types)
{
ImmutableMap.Builder<Symbol, Type> symbolTypes = ImmutableMap.builder();
SymbolAllocator symbolAllocator = new SymbolAllocator();
ImmutableMap.Builder<Symbol, Integer> symbolToInputMapping = ImmutableMap.builder();
ImmutableList.Builder<PageProjection> projections = ImmutableList.builder();
for (int channel = 0; channel < types.size(); channel++) {
Symbol symbol = new Symbol("h" + channel);
symbolTypes.put(symbol, types.get(channel));
Symbol symbol = symbolAllocator.newSymbol("h" + channel, types.get(channel));
symbolToInputMapping.put(symbol, channel);
projections.add(new InputPageProjection(channel, types.get(channel)));
}

Optional<Expression> hashExpression = HashGenerationOptimizer.getHashExpression(ImmutableList.copyOf(symbolTypes.build().keySet()));
Map<Symbol, Type> symbolTypes = symbolAllocator.getTypes().allTypes();
Optional<Expression> hashExpression = HashGenerationOptimizer.getHashExpression(
localQueryRunner.getMetadata(),
symbolAllocator,
ImmutableList.copyOf(symbolTypes.keySet()));
verify(hashExpression.isPresent());

Map<NodeRef<Expression>, Type> expressionTypes = new TypeAnalyzer(localQueryRunner.getSqlParser(), localQueryRunner.getMetadata())
.getTypes(session, TypeProvider.copyOf(symbolTypes.build()), hashExpression.get());
.getTypes(session, TypeProvider.copyOf(symbolTypes), hashExpression.get());

RowExpression translated = translate(hashExpression.get(), SCALAR, expressionTypes, symbolToInputMapping.build(), localQueryRunner.getMetadata(), session, false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public FilterStatsCalculator(Metadata metadata, ScalarStatsCalculator scalarStat
this.metadata = requireNonNull(metadata, "metadata is null");
this.scalarStatsCalculator = requireNonNull(scalarStatsCalculator, "scalarStatsCalculator is null");
this.normalizer = requireNonNull(normalizer, "normalizer is null");
this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde());
this.literalEncoder = new LiteralEncoder(metadata);
}

public PlanNodeStatsEstimate filterStats(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@

import java.util.function.Supplier;

import static io.prestosql.operator.scalar.TryFunction.NAME;
import static io.prestosql.spi.StandardErrorCode.DIVISION_BY_ZERO;
import static io.prestosql.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.prestosql.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;

@Description("internal try function for desugaring TRY")
@ScalarFunction(value = "$internal$try", hidden = true, deterministic = false)
@ScalarFunction(value = NAME, hidden = true, deterministic = false)
public final class TryFunction
{
public static final String NAME = "$internal$try";

private TryFunction() {}

@TypeParameter("T")
Expand Down
12 changes: 10 additions & 2 deletions presto-main/src/main/java/io/prestosql/sql/DynamicFilters.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.function.ScalarFunction;
import io.prestosql.spi.function.SqlType;
import io.prestosql.spi.function.TypeParameter;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.planner.FunctionCallBuilder;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.QualifiedName;
Expand All @@ -40,9 +44,13 @@ public final class DynamicFilters
{
private DynamicFilters() {}

public static Expression createDynamicFilterExpression(String id, SymbolReference input)
public static Expression createDynamicFilterExpression(Metadata metadata, String id, Type inputType, SymbolReference input)
{
return new FunctionCall(QualifiedName.of(Function.NAME), ImmutableList.of(new StringLiteral(id), input));
return new FunctionCallBuilder(metadata)
.setName(QualifiedName.of(Function.NAME))
.addArgument(VarcharType.VARCHAR, new StringLiteral(id))
.addArgument(inputType, input)
.build();
}

public static ExtractResult extractDynamicFilters(Expression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.prestosql.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
Expand All @@ -23,7 +22,6 @@
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.SymbolReference;
Expand All @@ -38,9 +36,9 @@

public class DesugarAtTimeZoneRewriter
{
public static Expression rewrite(Expression expression, Map<NodeRef<Expression>, Type> expressionTypes)
public static Expression rewrite(Expression expression, Map<NodeRef<Expression>, Type> expressionTypes, Metadata metadata)
{
return ExpressionTreeRewriter.rewriteWith(new Visitor(expressionTypes), expression);
return ExpressionTreeRewriter.rewriteWith(new Visitor(expressionTypes, metadata), expression);
}

private DesugarAtTimeZoneRewriter() {}
Expand All @@ -55,32 +53,44 @@ public static Expression rewrite(Expression expression, Session session, Metadat
}
Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression);

return rewrite(expression, expressionTypes);
return rewrite(expression, expressionTypes, metadata);
}

private static class Visitor
extends ExpressionRewriter<Void>
{
private final Map<NodeRef<Expression>, Type> expressionTypes;
private final Metadata metadata;

public Visitor(Map<NodeRef<Expression>, Type> expressionTypes)
public Visitor(Map<NodeRef<Expression>, Type> expressionTypes, Metadata metadata)
{
this.expressionTypes = ImmutableMap.copyOf(requireNonNull(expressionTypes, "expressionTypes is null"));
this.metadata = metadata;
}

@Override
public Expression rewriteAtTimeZone(AtTimeZone node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Type valueType = getType(node.getValue());
Expression value = treeRewriter.rewrite(node.getValue(), context);
Type type = getType(node.getValue());
if (type.equals(TIME)) {

if (valueType.equals(TIME)) {
valueType = TIME_WITH_TIME_ZONE;
value = new Cast(value, TIME_WITH_TIME_ZONE.getDisplayName());
}
else if (type.equals(TIMESTAMP)) {
else if (valueType.equals(TIMESTAMP)) {
valueType = TIMESTAMP_WITH_TIME_ZONE;
value = new Cast(value, TIMESTAMP_WITH_TIME_ZONE.getDisplayName());
}

return new FunctionCall(QualifiedName.of("at_timezone"), ImmutableList.of(value, treeRewriter.rewrite(node.getTimeZone(), context)));
Type timeZoneType = getType(node.getTimeZone());
Expression timeZone = treeRewriter.rewrite(node.getTimeZone(), context);

return new FunctionCallBuilder(metadata)
.setName(QualifiedName.of("at_timezone"))
.addArgument(valueType, value)
.addArgument(timeZoneType, timeZone)
.build();
}

private Type getType(Expression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,63 @@
package io.prestosql.sql.planner;

import com.google.common.collect.ImmutableList;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.LambdaExpression;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.sql.tree.TryExpression;
import io.prestosql.type.FunctionType;

import java.util.Map;

import static io.prestosql.operator.scalar.TryFunction.NAME;

public class DesugarTryExpressionRewriter
{
private DesugarTryExpressionRewriter() {}

public static Expression rewrite(Expression expression)
public static Expression rewrite(Expression expression, Metadata metadata, TypeAnalyzer typeAnalyzer, Session session, SymbolAllocator symbolAllocator)
{
return ExpressionTreeRewriter.rewriteWith(new Visitor(), expression);
if (expression instanceof SymbolReference) {
return expression;
}

Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(
session,
symbolAllocator.getTypes(),
expression);

return ExpressionTreeRewriter.rewriteWith(new Visitor(metadata, expressionTypes), expression);
}

private static class Visitor
extends ExpressionRewriter<Void>
{
private final Metadata metadata;
private final Map<NodeRef<Expression>, Type> expressionTypes;

public Visitor(Metadata metadata, Map<NodeRef<Expression>, Type> expressionTypes)
{
this.metadata = metadata;
this.expressionTypes = expressionTypes;
}

@Override
public Expression rewriteTryExpression(TryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Type type = expressionTypes.get(NodeRef.of(node));
Expression expression = treeRewriter.rewrite(node.getInnerExpression(), context);

return new FunctionCall(
QualifiedName.of("$internal$try"),
ImmutableList.of(new LambdaExpression(ImmutableList.of(), expression)));
return new FunctionCallBuilder(metadata)
.setName(QualifiedName.of(NAME))
.addArgument(new FunctionType(ImmutableList.of(), type), new LambdaExpression(ImmutableList.of(), expression))
.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ private static class Visitor
private Visitor(Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer)
{
this.metadata = requireNonNull(metadata, "metadata is null");
this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde());
this.literalEncoder = new LiteralEncoder(metadata);
this.session = requireNonNull(session, "session is null");
this.types = requireNonNull(types, "types is null");
this.functionInvoker = new InterpretedFunctionInvoker(metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@
import static io.prestosql.spi.type.TypeSignature.parseTypeSignature;
import static io.prestosql.spi.type.TypeUtils.readNativeValue;
import static io.prestosql.spi.type.TypeUtils.writeNativeValue;
import static io.prestosql.spi.type.VarcharType.VARCHAR;
import static io.prestosql.spi.type.VarcharType.createVarcharType;
import static io.prestosql.sql.analyzer.ConstantExpressionVerifier.verifyExpressionIsConstant;
import static io.prestosql.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer;
import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.prestosql.sql.gen.VarArgsToMapAdapterGenerator.generateVarArgsToMapAdapter;
import static io.prestosql.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.prestosql.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression;
import static io.prestosql.type.JsonType.JSON;
import static io.prestosql.type.LikeFunctions.isLikePattern;
import static io.prestosql.type.LikeFunctions.unescapeLiteralLikePattern;
import static io.prestosql.util.Failures.checkCondition;
Expand Down Expand Up @@ -211,11 +213,16 @@ private static Object evaluateConstantExpression(
analyzer.analyze(rewrite, Scope.create());

// remove syntax sugar
rewrite = DesugarAtTimeZoneRewriter.rewrite(rewrite, analyzer.getExpressionTypes());
rewrite = DesugarAtTimeZoneRewriter.rewrite(rewrite, analyzer.getExpressionTypes(), metadata);

// The optimization above may have rewritten the expression tree which breaks all the identity maps, so redo the analysis
// to re-analyze coercions that might be necessary
analyzer = createConstantAnalyzer(metadata, session, parameters, WarningCollector.NOOP);
analyzer.analyze(rewrite, Scope.create());

// expressionInterpreter/optimizer only understands a subset of expression types
// TODO: remove this when the new expression tree is implemented
Expression canonicalized = canonicalizeExpression(rewrite);
Expression canonicalized = canonicalizeExpression(rewrite, analyzer.getExpressionTypes(), metadata);

// The optimization above may have rewritten the expression tree which breaks all the identity maps, so redo the analysis
// to re-analyze coercions that might be necessary
Expand All @@ -232,7 +239,7 @@ private ExpressionInterpreter(Expression expression, Metadata metadata, Session
{
this.expression = requireNonNull(expression, "expression is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.literalEncoder = new LiteralEncoder(metadata.getBlockEncodingSerde());
this.literalEncoder = new LiteralEncoder(metadata);
this.session = requireNonNull(session, "session is null").toConnectorSession();
this.expressionTypes = ImmutableMap.copyOf(requireNonNull(expressionTypes, "expressionTypes is null"));
verify((expressionTypes.containsKey(NodeRef.of(expression))));
Expand Down Expand Up @@ -431,7 +438,7 @@ private Object processWithExceptionHandling(Expression expression, Object contex
// HACK
// Certain operations like 0 / 0 or likeExpression may throw exceptions.
// Wrap them a FunctionCall that will throw the exception if the expression is actually executed
return createFailureFunction(e, type(expression));
return createFailureFunction(e, type(expression), ExpressionInterpreter.this.metadata);
}
}

Expand Down Expand Up @@ -900,7 +907,14 @@ protected Object visitFunctionCall(FunctionCall node, Object context)

// do not optimize non-deterministic functions
if (optimize && (!function.isDeterministic() || hasUnresolvedValue(argumentValues) || node.getName().equals(QualifiedName.of("fail")))) {
return new FunctionCall(node.getName(), node.getWindow(), node.isDistinct(), toExpressions(argumentValues, argumentTypes));
verify(!node.isDistinct(), "window does not support distinct");
verify(!node.getOrderBy().isPresent(), "window does not support order by");
verify(!node.getFilter().isPresent(), "window does not support filter");
return new FunctionCallBuilder(metadata)
.setName(node.getName())
.setWindow(node.getWindow())
.setArguments(argumentTypes, toExpressions(argumentValues, argumentTypes))
.build();
}
return functionInvoker.invoke(functionSignature, session, argumentValues);
}
Expand Down Expand Up @@ -1112,7 +1126,12 @@ protected Object visitArrayConstructor(ArrayConstructor node, Object context)
Object value = process(expression, context);
if (value instanceof Expression) {
checkCondition(node.getValues().size() <= 254, NOT_SUPPORTED, "Too many arguments for array constructor");
return visitFunctionCall(new FunctionCall(QualifiedName.of(ArrayConstructor.ARRAY_CONSTRUCTOR), node.getValues()), context);
return visitFunctionCall(
new FunctionCallBuilder(metadata)
.setName(QualifiedName.of(ArrayConstructor.ARRAY_CONSTRUCTOR))
.setArguments(types(node.getValues()), node.getValues())
.build(),
context);
}
writeNativeValue(elementType, arrayBlockBuilder, value);
}
Expand All @@ -1123,13 +1142,13 @@ protected Object visitArrayConstructor(ArrayConstructor node, Object context)
@Override
protected Object visitCurrentUser(CurrentUser node, Object context)
{
return visitFunctionCall(DesugarCurrentUser.getCall(node), context);
return visitFunctionCall(DesugarCurrentUser.getCall(node, metadata), context);
}

@Override
protected Object visitCurrentPath(CurrentPath node, Object context)
{
return visitFunctionCall(DesugarCurrentPath.getCall(node), context);
return visitFunctionCall(DesugarCurrentPath.getCall(node, metadata), context);
}

@Override
Expand Down Expand Up @@ -1213,9 +1232,17 @@ protected Object visitNode(Node node, Object context)
throw new UnsupportedOperationException("Evaluator visitor can only handle Expression nodes");
}

private List<Type> types(Expression... types)
private List<Type> types(Expression... expressions)
{
return Stream.of(types)
return Stream.of(expressions)
.map(NodeRef::of)
.map(expressionTypes::get)
.collect(toImmutableList());
}

private List<Type> types(List<Expression> expressions)
{
return expressions.stream()
.map(NodeRef::of)
.map(expressionTypes::get)
.collect(toImmutableList());
Expand Down Expand Up @@ -1271,13 +1298,19 @@ public int getPosition(int channel)
}
}

private static Expression createFailureFunction(RuntimeException exception, Type type)
private static Expression createFailureFunction(RuntimeException exception, Type type, Metadata metadata)
{
requireNonNull(exception, "Exception is null");

String failureInfo = JsonCodec.jsonCodec(FailureInfo.class).toJson(Failures.toFailure(exception).toFailureInfo());
FunctionCall jsonParse = new FunctionCall(QualifiedName.of("json_parse"), ImmutableList.of(new StringLiteral(failureInfo)));
FunctionCall failureFunction = new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(jsonParse));
FunctionCall jsonParse = new FunctionCallBuilder(metadata)
.setName(QualifiedName.of("json_parse"))
.addArgument(VARCHAR, new StringLiteral(failureInfo))
.build();
FunctionCall failureFunction = new FunctionCallBuilder(metadata)
.setName(QualifiedName.of("fail"))
.addArgument(JSON, jsonParse)
.build();

return new Cast(failureFunction, type.getTypeSignature().toString());
}
Expand Down
Loading

0 comments on commit 9f78af6

Please sign in to comment.