feat(spark): add Window support
Add support for the SQL OVER clause by translating Spark window expressions to and from the Substrait ConsistentPartitionWindow relation.
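
For context, a minimal usage sketch (illustrative, not part of the commit) of the kind of query this enables; the table and column names are hypothetical:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.range(10).selectExpr("id % 3 AS grp", "id AS v").createOrReplaceTempView("t")

// An OVER-clause query; its optimized plan contains a Window node,
// which can now be translated to and from Substrait.
val df = spark.sql(
  "SELECT grp, v, row_number() OVER (PARTITION BY grp ORDER BY v) AS rn FROM t")
df.show()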

Signed-off-by: Andrew Coleman <[email protected]>
andrew-coleman committed Oct 25, 2024
1 parent b8ccd8b commit 86a0548
Showing 12 changed files with 368 additions and 23 deletions.
13 changes: 13 additions & 0 deletions spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -152,6 +152,19 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
})
}

override def visit(window: ConsistentPartitionWindow): String = {
withBuilder(window, 10)(
builder => {
builder
.append("functions=")
.append(window.getWindowFunctions)
.append("partitions=")
.append(window.getPartitionExpressions)
.append("sorts=")
.append(window.getSorts)
})
}

override def visit(localFiles: LocalFiles): String = {
withBuilder(localFiles, 10)(
builder => {
6 changes: 5 additions & 1 deletion spark/src/main/scala/io/substrait/spark/SparkExtension.scala
@@ -16,7 +16,7 @@
*/
package io.substrait.spark

-import io.substrait.spark.expression.ToAggregateFunction
+import io.substrait.spark.expression.{ToAggregateFunction, ToWindowFunction}

import io.substrait.extension.SimpleExtension

@@ -43,4 +43,8 @@ object SparkExtension {

val toAggregateFunction: ToAggregateFunction = ToAggregateFunction(
JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.aggregateFunctions()))

val toWindowFunction: ToWindowFunction = ToWindowFunction(
JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.windowFunctions())
)
}
@@ -21,7 +21,7 @@ import io.substrait.spark.ToSubstraitType
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.types.DataType

@@ -238,7 +238,6 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
val parent: FunctionConverter[F, T]) {

def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = {

val opTypes = operands.map(_.getType)
val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable)
val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE))
@@ -250,17 +249,23 @@
.map(name + ":" + _)
.find(k => directMap.contains(k))

-    if (directMatchKey.isDefined) {
+    if (operands.isEmpty) {
+      val variant = directMap(name + ":")
+      variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
+      Option(parent.generateBinding(expression, variant, operands, outputType))
+    } else if (directMatchKey.isDefined) {
val variant = directMap(directMatchKey.get)
variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
val funcArgs: Seq[FunctionArg] = operands
Option(parent.generateBinding(expression, variant, funcArgs, outputType))
} else if (singularInputType.isDefined) {
-      val types = expression match {
-        case agg: AggregateExpression => agg.aggregateFunction.children.map(_.dataType)
-        case other => other.children.map(_.dataType)
+      val children = expression match {
+        case agg: AggregateExpression => agg.aggregateFunction.children
+        case win: WindowExpression => win.windowFunction.children
+        case other => other.children
       }
-      val nullable = expression.children.exists(e => e.nullable)
+      val types = children.map(_.dataType)
+      val nullable = children.exists(e => e.nullable)
FunctionFinder
.leastRestrictive(types)
.flatMap(
@@ -298,6 +303,4 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
}
})
}

-  def allowedArgCount(count: Int): Boolean = true
}
@@ -80,9 +80,22 @@ class FunctionMappings {
s[HyperLogLogPlusPlus]("approx_count_distinct")
)

val WINDOW_SIGS: Seq[Sig] = Seq(
s[RowNumber]("row_number"),
s[Rank]("rank"),
s[DenseRank]("dense_rank"),
s[PercentRank]("percent_rank"),
s[CumeDist]("cume_dist"),
s[NTile]("ntile"),
s[Lead]("lead"),
s[Lag]("lag"),
s[NthValue]("nth_value")
)

lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap
lazy val aggregate_functions_map: Map[Class[_], Sig] =
AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap
lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap
}

object FunctionMappings extends FunctionMappings
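
For illustration, a sketch of how the new map is consulted (assuming Sig exposes its expClass and Substrait name, as used elsewhere in this module):

import org.apache.spark.sql.catalyst.expressions.RowNumber

// Resolve a Spark window expression class to its Substrait function name.
val sig = FunctionMappings.window_functions_map(classOf[RowNumber])
// sig.name == "row_number"; ToWindowFunction uses this key to find the
// matching Substrait window function variant.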
@@ -53,7 +53,6 @@ abstract class ToAggregateFunction(functions: Seq[SimpleExtension.AggregateFunct
expression: AggregateExpression,
operands: Seq[SExpression]): Option[AggregateFunctionInvocation] = {
Option(signatures.get(expression.aggregateFunction.getClass))
-      .filter(m => m.allowedArgCount(2))
.flatMap(m => m.attemptMatch(expression, operands))
}

@@ -42,7 +42,6 @@ abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVar

def convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = {
Option(signatures.get(expression.getClass))
-      .filter(m => m.allowedArgCount(2))
.flatMap(m => m.attemptMatch(expression, operands))
}
}
@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.substrait.spark.expression

import io.substrait.spark.expression.ToWindowFunction.fromSpark

import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Expression, FrameType, Literal, OffsetWindowFunction, RangeFrame, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnspecifiedFrame, WindowExpression, WindowFrame, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.types.{IntegerType, LongType}

import io.substrait.`type`.Type
import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg, WindowBound}
import io.substrait.expression.Expression.WindowBoundsType
import io.substrait.expression.WindowBound.{CURRENT_ROW, UNBOUNDED, WindowBoundVisitor}
import io.substrait.extension.SimpleExtension
import io.substrait.relation.ConsistentPartitionWindow.WindowRelFunctionInvocation

import scala.collection.JavaConverters

abstract class ToWindowFunction(functions: Seq[SimpleExtension.WindowFunctionVariant])
extends FunctionConverter[SimpleExtension.WindowFunctionVariant, WindowRelFunctionInvocation](
functions) {

override def generateBinding(
sparkExp: Expression,
function: SimpleExtension.WindowFunctionVariant,
arguments: Seq[FunctionArg],
outputType: Type): WindowRelFunctionInvocation = {

val (frameType, lower, upper) = sparkExp match {
case WindowExpression(
_,
WindowSpecDefinition(_, _, SpecifiedWindowFrame(frameType, lower, upper))) =>
(fromSpark(frameType), fromSpark(lower), fromSpark(upper))
case WindowExpression(_, WindowSpecDefinition(_, orderSpec, UnspecifiedFrame)) =>
if (orderSpec.isEmpty) {
(WindowBoundsType.ROWS, UNBOUNDED, UNBOUNDED)
} else {
(WindowBoundsType.RANGE, UNBOUNDED, CURRENT_ROW)
}

case _ => throw new UnsupportedOperationException(s"Unsupported window expression: $sparkExp")
}

ExpressionCreator.windowRelFunction(
function,
outputType,
SExpression.AggregationPhase.INITIAL_TO_RESULT, // use defaults...
SExpression.AggregationInvocation.ALL, // Spark doesn't define these
frameType,
lower,
upper,
JavaConverters.asJavaIterable(arguments)
)
}

def convert(
expression: WindowExpression,
operands: Seq[SExpression]): Option[WindowRelFunctionInvocation] = {
val cls = expression.windowFunction match {
case agg: AggregateExpression => agg.aggregateFunction.getClass
case other => other.getClass
}

Option(signatures.get(cls))
.flatMap(m => m.attemptMatch(expression, operands))
}

def apply(
expression: WindowExpression,
operands: Seq[SExpression]): WindowRelFunctionInvocation = {
convert(expression, operands).getOrElse(throw new UnsupportedOperationException(
s"Unable to find binding for call ${expression.windowFunction} -- $operands -- $expression"))
}
}

object ToWindowFunction {
def fromSpark(frameType: FrameType): WindowBoundsType = frameType match {
case RowFrame => WindowBoundsType.ROWS
case RangeFrame => WindowBoundsType.RANGE
case other => throw new UnsupportedOperationException(s"Unsupported bounds type: $other.")
}

def fromSpark(bound: Expression): WindowBound = bound match {
case UnboundedPreceding => WindowBound.UNBOUNDED
case UnboundedFollowing => WindowBound.UNBOUNDED
case CurrentRow => WindowBound.CURRENT_ROW
case e: Literal =>
e.dataType match {
case IntegerType | LongType =>
val offset = e.eval().asInstanceOf[Int]
if (offset < 0) WindowBound.Preceding.of(-offset)
else if (offset == 0) WindowBound.CURRENT_ROW
else WindowBound.Following.of(offset)
}
case _ => throw new UnsupportedOperationException(s"Unexpected bound: $bound")
}

def toSparkFrame(
boundsType: WindowBoundsType,
lowerBound: WindowBound,
upperBound: WindowBound): WindowFrame = {
val frameType = boundsType match {
case WindowBoundsType.ROWS => RowFrame
case WindowBoundsType.RANGE => RangeFrame
case WindowBoundsType.UNSPECIFIED => return UnspecifiedFrame
}
SpecifiedWindowFrame(
frameType,
toSparkBound(lowerBound, isLower = true),
toSparkBound(upperBound, isLower = false))
}

private def toSparkBound(bound: WindowBound, isLower: Boolean): Expression = {
bound.accept(new WindowBoundVisitor[Expression, Exception] {

override def visit(preceding: WindowBound.Preceding): Expression =
Literal(-preceding.offset().intValue())

override def visit(following: WindowBound.Following): Expression =
Literal(following.offset().intValue())

override def visit(currentRow: WindowBound.CurrentRow): Expression = CurrentRow

override def visit(unbounded: WindowBound.Unbounded): Expression =
if (isLower) UnboundedPreceding else UnboundedFollowing
})
}

def apply(functions: Seq[SimpleExtension.WindowFunctionVariant]): ToWindowFunction = {
new ToWindowFunction(functions) {
override def getSigs: Seq[Sig] =
FunctionMappings.WINDOW_SIGS ++ FunctionMappings.AGGREGATE_SIGS
}
}

}
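
A round-trip sketch of the bound conversions above (illustrative; uses only the types already imported in this file): Spark's ROWS BETWEEN 2 PRECEDING AND CURRENT ROW maps to Substrait ROWS bounds and back.

import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Literal, RowFrame, SpecifiedWindowFrame}

val frame = SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)
val boundsType = ToWindowFunction.fromSpark(frame.frameType) // WindowBoundsType.ROWS
val lower = ToWindowFunction.fromSpark(frame.lower) // WindowBound.Preceding.of(2)
val upper = ToWindowFunction.fromSpark(frame.upper) // WindowBound.CURRENT_ROW

// toSparkFrame reverses the mapping, reconstructing the original frame.
val roundTripped = ToWindowFunction.toSparkFrame(boundsType, lower, upper)
// roundTripped == SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)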
@@ -119,6 +119,56 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}
}

override def visit(window: relation.ConsistentPartitionWindow): LogicalPlan = {
val child = window.getInput.accept(this)
withChild(child) {
val partitions = window.getPartitionExpressions.asScala
.map(expr => expr.accept(expressionConverter))
val sortOrders = window.getSorts.asScala.map(toSortOrder)
val windowExpressions = window.getWindowFunctions.asScala
.map(
func => {
val arguments = func.arguments().asScala.zipWithIndex.map {
case (arg, i) =>
arg.accept(func.declaration(), i, expressionConverter)
}
val windowFunction = SparkExtension.toWindowFunction
.getSparkExpressionFromSubstraitFunc(func.declaration.key, func.outputType)
.map(sig => sig.makeCall(arguments))
.map {
case win: WindowFunction => win
case agg: AggregateFunction =>
AggregateExpression(
agg,
ToAggregateFunction.toSpark(func.aggregationPhase()),
ToAggregateFunction.toSpark(func.invocation()),
None)
}
.getOrElse({
val msg = String.format(
"Unable to convert Window function %s(%s).",
func.declaration.name,
func.arguments.asScala
.map {
case ea: exp.EnumArg => ea.value.toString
case e: SExpression => e.getType.accept(new StringTypeVisitor)
case t: Type => t.accept(new StringTypeVisitor)
case a => throw new IllegalStateException("Unexpected value: " + a)
}
.mkString(", ")
)
throw new IllegalArgumentException(msg)
})
val frame =
ToWindowFunction.toSparkFrame(func.boundsType(), func.lowerBound(), func.upperBound())
val spec = WindowSpecDefinition(partitions, sortOrders, frame)
WindowExpression(windowFunction, spec)
})
.map(toNamedExpression)
Window(windowExpressions, partitions, sortOrders, child)
}
}

override def visit(join: relation.Join): LogicalPlan = {
val left = join.getLeft.accept(this)
val right = join.getRight.accept(this)
@@ -162,6 +212,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}
SortOrder(expression, direction, nullOrdering, Seq.empty)
}

override def visit(fetch: relation.Fetch): LogicalPlan = {
val child = fetch.getInput.accept(this)
val limit = fetch.getCount.getAsLong.intValue()
@@ -180,6 +231,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
Offset(toLiteral(offset), child)
}
}

override def visit(sort: relation.Sort): LogicalPlan = {
val child = sort.getInput.accept(this)
withChild(child) {