Skip to content

Commit

Permalink
feat: support reading of Substrait plans containing window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarua committed Aug 10, 2023
1 parent 418ad58 commit 9c65d6b
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.expression.proto;

import io.substrait.expression.AbstractExpressionVisitor;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
Expand Down Expand Up @@ -420,7 +421,8 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R
expr.partitionBy().stream()
.map(e -> e.accept(this))
.collect(java.util.stream.Collectors.toList());
var builder = Expression.WindowFunction.newBuilder();
var outputType = expr.getType().accept(typeProtoConverter);
var builder = Expression.WindowFunction.newBuilder().setOutputType(outputType);
if (expr.hasNormalAggregateFunction()) {
var aggMeasureFunc = expr.aggregateFunction().getFunction();
var funcReference = extensionCollector.getFunctionReference(aggMeasureFunc.declaration());
Expand All @@ -429,18 +431,22 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R
aggMeasureFunc.arguments().stream()
.map(a -> a.accept(aggMeasureFunc.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
var ordinal = aggMeasureFunc.aggregationPhase().ordinal();
builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args);
builder
.setFunctionReference(funcReference)
.setPhase(aggMeasureFunc.aggregationPhase().toProto())
.addAllArguments(args);
} else {
var windowFunc = expr.windowFunction().getFunction();
var funcReference = extensionCollector.getFunctionReference(windowFunc.declaration());
var ordinal = windowFunc.aggregationPhase().ordinal();
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
var args =
windowFunc.arguments().stream()
.map(a -> a.accept(windowFunc.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args);
builder
.setFunctionReference(funcReference)
.setPhase(windowFunc.aggregationPhase().toProto())
.addAllArguments(args);
}
var sortFields =
expr.orderBy().stream()
Expand All @@ -465,57 +471,68 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R
.build();
}

/**
 * Visitor that extracts the offset of a bounded window frame from an integer literal.
 *
 * <p>Handles {@code I8Literal}, {@code I16Literal}, {@code I32Literal} and {@code I64Literal}
 * expressions and returns their value as a {@link Long}. Offsets smaller than 1 are rejected with
 * a {@link RuntimeException}.
 */
private static class LiteralToWindowBoundOffset
    extends AbstractExpressionVisitor<Long, RuntimeException> {

  /**
   * Rejects an expression that is not one of the integer literals handled below.
   *
   * <p>NOTE(review): presumably {@code AbstractExpressionVisitor} routes every {@code visit}
   * overload that is not overridden here to this method; confirm against the base class.
   *
   * @param expr the unsupported expression, included in the error message
   * @throws RuntimeException always
   */
  @Override
  public Long visitFallback(io.substrait.expression.Expression expr) {
    throw new RuntimeException(
        String.format("Expected positive integer for Window Bound offset, received: %s", expr));
  }

  /**
   * Returns {@code offset} unchanged when it is at least 1.
   *
   * @param offset the candidate window bound offset
   * @return {@code offset}
   * @throws RuntimeException if {@code offset} is zero or negative
   */
  private static long requirePositive(long offset) {
    if (offset < 1) {
      throw new RuntimeException(
          String.format("Expected positive offset for Window Bound offset, received: %d", offset));
    }
    return offset;
  }

  @Override
  public Long visit(io.substrait.expression.Expression.I8Literal expr) throws RuntimeException {
    return requirePositive(expr.value());
  }

  @Override
  public Long visit(io.substrait.expression.Expression.I16Literal expr) throws RuntimeException {
    return requirePositive(expr.value());
  }

  @Override
  public Long visit(io.substrait.expression.Expression.I32Literal expr) throws RuntimeException {
    return requirePositive(expr.value());
  }

  @Override
  public Long visit(io.substrait.expression.Expression.I64Literal expr) throws RuntimeException {
    return requirePositive(expr.value());
  }
}

/**
 * Converts a {@code WindowBound} into its protobuf {@code Expression.WindowFunction.Bound},
 * dispatching on {@code windowBound.boundedKind()}.
 *
 * <p>NOTE(review): this listing is a rendered diff without +/- markers. Lines of the previous
 * implementation (a statement {@code switch} that assigns a local {@code bound} and returns it at
 * the end) are interleaved with lines of the new implementation (a {@code return switch}
 * expression), so the span is not compilable as shown.
 *
 * <p>As far as the new lines show:
 *
 * <ul>
 *   <li>{@code CURRENT_ROW} produces a bound with the default {@code CurrentRow} instance set.
 *   <li>{@code BOUNDED} casts to {@code BoundedWindowBound}, obtains the offset through {@code
 *       LiteralToWindowBoundOffset} and sets either {@code Preceding} or {@code Following}
 *       according to the bound's direction.
 *   <li>{@code UNBOUNDED} produces a bound with only the default {@code Unbounded} instance set.
 * </ul>
 *
 * <p>The previous lines instead cast the offset directly to {@code I32Literal}, additionally set
 * an empty {@code Preceding}/{@code Following} on unbounded bounds, and had a {@code default}
 * branch throwing {@link RuntimeException}.
 *
 * @param windowBound the bound to convert
 * @return the protobuf representation of {@code windowBound}
 */
private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) {
var boundedKind = windowBound.boundedKind();
Expression.WindowFunction.Bound bound = null;
switch (boundedKind) {
case CURRENT_ROW -> bound =
Expression.WindowFunction.Bound.newBuilder()
.setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance())
.build();
return switch (boundedKind) {
case CURRENT_ROW -> Expression.WindowFunction.Bound.newBuilder()
.setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance())
.build();
case BOUNDED -> {
WindowBound.BoundedWindowBound boundedWindowBound =
(WindowBound.BoundedWindowBound) windowBound;
var offset = boundedWindowBound.offset();
boolean isPreceding = boundedWindowBound.direction() == WindowBound.Direction.PRECEDING;
io.substrait.expression.Expression.I32Literal offsetLiteral =
(io.substrait.expression.Expression.I32Literal) offset;
var offsetVal = offsetLiteral.value();
var boundedProto = Expression.WindowFunction.Bound.Unbounded.getDefaultInstance();
if (isPreceding) {
var offsetProto =
Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offsetVal).build();
bound = Expression.WindowFunction.Bound.newBuilder().setPreceding(offsetProto).build();
} else {
var offsetProto =
Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offsetVal).build();
bound = Expression.WindowFunction.Bound.newBuilder().setFollowing(offsetProto).build();
}
}
case UNBOUNDED -> {
WindowBound.UnboundedWindowBound unboundedWindowBound =
(WindowBound.UnboundedWindowBound) windowBound;
boolean isPreceding = unboundedWindowBound.direction() == WindowBound.Direction.PRECEDING;
var unboundedProto = Expression.WindowFunction.Bound.Unbounded.getDefaultInstance();
if (isPreceding) {
var preceding = Expression.WindowFunction.Bound.Preceding.newBuilder().build();
bound =
Expression.WindowFunction.Bound.newBuilder()
.setUnbounded(unboundedProto)
.setPreceding(preceding)
.build();
} else {
var following = Expression.WindowFunction.Bound.Following.newBuilder().build();
bound =
Expression.WindowFunction.Bound.newBuilder()
.setUnbounded(unboundedProto)
.setFollowing(following)
.build();
}
// The visitor accepts I8/I16/I32/I64 literals and throws for offsets below 1.
var offset = boundedWindowBound.offset().accept(new LiteralToWindowBoundOffset());
yield switch (boundedWindowBound.direction()) {
case PRECEDING -> Expression.WindowFunction.Bound.newBuilder()
.setPreceding(
Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offset))
.build();
case FOLLOWING -> Expression.WindowFunction.Bound.newBuilder()
.setFollowing(
Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offset))
.build();
};
}
default -> throw new RuntimeException(
String.format("Unexpected Expression.WindowFunction.Bound enum:%s", boundedKind));
}
return bound;
case UNBOUNDED -> Expression.WindowFunction.Bound.newBuilder()
.setUnbounded(Expression.WindowFunction.Bound.Unbounded.getDefaultInstance())
.build();
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.WindowBound;
import io.substrait.expression.WindowFunctionInvocation;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.ImmutableSimpleExtension;
import io.substrait.extension.SimpleExtension;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.type.Type;
Expand Down Expand Up @@ -109,7 +112,73 @@ public Expression from(io.substrait.proto.Expression expr) {
yield ImmutableExpression.ScalarFunctionInvocation.builder()
.addAllArguments(args)
.declaration(declaration)
.outputType(protoTypeConverter.from(expr.getScalarFunction().getOutputType()))
.outputType(protoTypeConverter.from(scalarFunction.getOutputType()))
.build();
}
case WINDOW_FUNCTION -> {
var windowFunction = expr.getWindowFunction();
var functionReference = windowFunction.getFunctionReference();
SimpleExtension.WindowFunctionVariant functionVariant;
try {
functionVariant = lookup.getWindowFunction(functionReference, extensions);
} catch (RuntimeException e) {
// TODO: Ideally we shouldn't need to catch a RuntimeException to be able to attempt our
// second lookup
var aggFunctionVariant = lookup.getAggregateFunction(functionReference, extensions);
functionVariant =
ImmutableSimpleExtension.WindowFunctionVariant.builder()
// Sets all fields declared in the Function interface
.from(aggFunctionVariant)
// Set WindowFunctionVariant fields
.decomposability(aggFunctionVariant.decomposability())
.intermediate(aggFunctionVariant.intermediate())
// Aggregate Functions used in Windows have WindowType Streaming
.windowType(SimpleExtension.WindowType.STREAMING)
.build();
}
final SimpleExtension.WindowFunctionVariant declaration = functionVariant;

var pF = new FunctionArg.ProtoFrom(this, protoTypeConverter);
var args =
IntStream.range(0, windowFunction.getArgumentsCount())
.mapToObj(i -> pF.convert(declaration, i, windowFunction.getArguments(i)))
.collect(java.util.stream.Collectors.toList());
var partitionExprs =
windowFunction.getPartitionsList().stream()
.map(this::from)
.collect(java.util.stream.Collectors.toList());
var sortFields =
windowFunction.getSortsList().stream()
.map(
s ->
Expression.SortField.builder()
.direction(Expression.SortDirection.fromProto(s.getDirection()))
.expr(from(s.getExpr()))
.build())
.collect(java.util.stream.Collectors.toList());
var wfi =
WindowFunctionInvocation.builder()
.addAllArguments(args)
.declaration(declaration)
.outputType(protoTypeConverter.from(windowFunction.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase()))
.addAllSort(sortFields)
.invocation(
Expression.AggregationInvocation.fromProto(windowFunction.getInvocation()))
.build();

WindowBound lowerBound = toLowerBound(windowFunction.getLowerBound());
WindowBound upperBound = toUpperBound(windowFunction.getUpperBound());

var wf = ImmutableExpression.WindowFunction.builder().function(wfi).build();
yield Expression.Window.builder()
.windowFunction(wf)
.hasNormalAggregateFunction(false)
.type(protoTypeConverter.from(windowFunction.getOutputType()))
.partitionBy(partitionExprs)
.orderBy(sortFields)
.lowerBound(lowerBound)
.upperBound(upperBound)
.build();
}
case IF_THEN -> {
Expand Down Expand Up @@ -199,13 +268,51 @@ public Expression from(io.substrait.proto.Expression expr) {
}
}

// TODO window, enum.
case WINDOW_FUNCTION, ENUM -> throw new UnsupportedOperationException(
// TODO enum.
case ENUM -> throw new UnsupportedOperationException(
"Unsupported type: " + expr.getRexTypeCase());
default -> throw new IllegalArgumentException("Unknown type: " + expr.getRexTypeCase());
};
}

/**
 * Converts the protobuf lower bound of a window frame into a {@link WindowBound}.
 *
 * <p>An unbounded {@code PRECEDING} bound is handed to {@code toWindowBound} as the default to use
 * when the proto bound does not describe an explicit position.
 *
 * @param bound the protobuf lower bound
 * @return the converted bound
 */
private WindowBound toLowerBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
  var unboundedPreceding =
      WindowBound.UnboundedWindowBound.builder().direction(WindowBound.Direction.PRECEDING).build();
  return toWindowBound(bound, unboundedPreceding);
}

/**
 * Converts the protobuf upper bound of a window frame into a {@link WindowBound}.
 *
 * <p>An unbounded {@code FOLLOWING} bound is handed to {@code toWindowBound} as the default to use
 * when the proto bound does not describe an explicit position.
 *
 * @param bound the protobuf upper bound
 * @return the converted bound
 */
private WindowBound toUpperBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
  var unboundedFollowing =
      WindowBound.UnboundedWindowBound.builder().direction(WindowBound.Direction.FOLLOWING).build();
  return toWindowBound(bound, unboundedFollowing);
}

/**
 * Converts a protobuf window bound into a {@link WindowBound}.
 *
 * <ul>
 *   <li>{@code PRECEDING} / {@code FOLLOWING}: a bounded bound in that direction whose offset is
 *       the proto offset wrapped in an {@code I64Literal}.
 *   <li>{@code CURRENT_ROW}: {@code WindowBound.CURRENT_ROW}.
 *   <li>{@code UNBOUNDED} or no kind set: {@code defaultBound}.
 * </ul>
 *
 * @param bound the protobuf bound to convert
 * @param defaultBound the bound returned when {@code bound} is unbounded or has no kind set
 * @return the converted bound
 */
private WindowBound toWindowBound(
    io.substrait.proto.Expression.WindowFunction.Bound bound, WindowBound defaultBound) {
  return switch (bound.getKindCase()) {
    case UNBOUNDED, KIND_NOT_SET -> defaultBound;
    case CURRENT_ROW -> WindowBound.CURRENT_ROW;
    case PRECEDING -> {
      var precedingOffset =
          Expression.Literal.I64Literal.builder().value(bound.getPreceding().getOffset()).build();
      yield WindowBound.BoundedWindowBound.builder()
          .direction(WindowBound.Direction.PRECEDING)
          .offset(precedingOffset)
          .build();
    }
    case FOLLOWING -> {
      var followingOffset =
          Expression.Literal.I64Literal.builder().value(bound.getFollowing().getOffset()).build();
      yield WindowBound.BoundedWindowBound.builder()
          .direction(WindowBound.Direction.FOLLOWING)
          .offset(followingOffset)
          .build();
    }
  };
}

public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
return switch (literal.getLiteralTypeCase()) {
case BOOLEAN -> ExpressionCreator.bool(literal.getNullable(), literal.getBoolean());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ enum Decomposability {
MANY
}

/**
 * Window type of a window function variant: {@code PARTITION} or {@code STREAMING}.
 *
 * <p>NOTE(review): both the previous package-private header and the new {@code public} header are
 * shown because this listing is a rendered diff. The proto-to-expression converter in this same
 * commit references {@code SimpleExtension.WindowType.STREAMING} when it wraps an aggregate
 * function as a window function variant; that is presumably why the enum was made public.
 */
enum WindowType {
public enum WindowType {
PARTITION,
STREAMING
}
Expand Down
Loading

0 comments on commit 9c65d6b

Please sign in to comment.