From 86a0548739bdac2bc4a82cc21ff774935f9880e5 Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Fri, 25 Oct 2024 09:28:02 +0100 Subject: [PATCH] feat(spark): add Window support To support the OVER clause in SQL Signed-off-by: Andrew Coleman --- .../substrait/debug/RelToVerboseString.scala | 13 ++ .../io/substrait/spark/SparkExtension.scala | 6 +- .../spark/expression/FunctionConverter.scala | 21 +-- .../spark/expression/FunctionMappings.scala | 13 ++ .../expression/ToAggregateFunction.scala | 1 - .../spark/expression/ToScalarFunction.scala | 1 - .../spark/expression/ToWindowFunction.scala | 151 ++++++++++++++++++ .../spark/logical/ToLogicalPlan.scala | 52 ++++++ .../spark/logical/ToSubstraitRel.scala | 33 +++- .../main/scala/io/substrait/utils/Util.scala | 1 + .../scala/io/substrait/spark/TPCDSPlan.scala | 20 +-- .../scala/io/substrait/spark/WindowPlan.scala | 79 +++++++++ 12 files changed, 368 insertions(+), 23 deletions(-) create mode 100644 spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala create mode 100644 spark/src/test/scala/io/substrait/spark/WindowPlan.scala diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index 0ba749b9e..79b34462f 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -152,6 +152,19 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } + override def visit(window: ConsistentPartitionWindow): String = { + withBuilder(window, 10)( + builder => { + builder + .append("functions=") + .append(window.getWindowFunctions) + .append("partitions=") + .append(window.getPartitionExpressions) + .append("sorts=") + .append(window.getSorts) + }) + } + override def visit(localFiles: LocalFiles): String = { withBuilder(localFiles, 10)( builder => { diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index d61a06d3e..53b5bfaaf 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -16,7 +16,7 @@ */ package io.substrait.spark -import io.substrait.spark.expression.ToAggregateFunction +import io.substrait.spark.expression.{ToAggregateFunction, ToWindowFunction} import io.substrait.extension.SimpleExtension @@ -43,4 +43,8 @@ object SparkExtension { val toAggregateFunction: ToAggregateFunction = ToAggregateFunction( JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.aggregateFunctions())) + + val toWindowFunction: ToWindowFunction = ToWindowFunction( + JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.windowFunctions()) + ) } diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala index e32e5e583..5c0f72692 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala @@ -21,7 +21,7 @@ import io.substrait.spark.ToSubstraitType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression} import 
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.types.DataType @@ -238,7 +238,6 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( val parent: FunctionConverter[F, T]) { def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = { - val opTypes = operands.map(_.getType) val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable) val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE)) @@ -250,17 +249,23 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( .map(name + ":" + _) .find(k => directMap.contains(k)) - if (directMatchKey.isDefined) { + if (operands.isEmpty) { + val variant = directMap(name + ":") + variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType) + Option(parent.generateBinding(expression, variant, operands, outputType)) + } else if (directMatchKey.isDefined) { val variant = directMap(directMatchKey.get) variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType) val funcArgs: Seq[FunctionArg] = operands Option(parent.generateBinding(expression, variant, funcArgs, outputType)) } else if (singularInputType.isDefined) { - val types = expression match { - case agg: AggregateExpression => agg.aggregateFunction.children.map(_.dataType) - case other => other.children.map(_.dataType) + val children = expression match { + case agg: AggregateExpression => agg.aggregateFunction.children + case win: WindowExpression => win.windowFunction.children + case other => other.children } - val nullable = expression.children.exists(e => e.nullable) + val types = children.map(_.dataType) + val nullable = children.exists(e => e.nullable) FunctionFinder .leastRestrictive(types) .flatMap( @@ -298,6 +303,4 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( } }) } - - def allowedArgCount(count: Int): Boolean = true } diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index ac4fed93e..d37274822 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -80,9 +80,22 @@ class FunctionMappings { s[HyperLogLogPlusPlus]("approx_count_distinct") ) + val WINDOW_SIGS: Seq[Sig] = Seq( + s[RowNumber]("row_number"), + s[Rank]("rank"), + s[DenseRank]("dense_rank"), + s[PercentRank]("percent_rank"), + s[CumeDist]("cume_dist"), + s[NTile]("ntile"), + s[Lead]("lead"), + s[Lag]("lag"), + s[NthValue]("nth_value") + ) + lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap lazy val aggregate_functions_map: Map[Class[_], Sig] = AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap + lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap } object FunctionMappings extends FunctionMappings diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala index 0c5b50c6c..9e959e47e 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala @@ -53,7 +53,6 @@ abstract class ToAggregateFunction(functions: Seq[SimpleExtension.AggregateFunct expression: AggregateExpression, operands: Seq[SExpression]): 
Option[AggregateFunctionInvocation] = { Option(signatures.get(expression.aggregateFunction.getClass)) - .filter(m => m.allowedArgCount(2)) .flatMap(m => m.attemptMatch(expression, operands)) } diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala index cd23611ec..dac4873c3 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala @@ -42,7 +42,6 @@ abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVar def convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = { Option(signatures.get(expression.getClass)) - .filter(m => m.allowedArgCount(2)) .flatMap(m => m.attemptMatch(expression, operands)) } } diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala new file mode 100644 index 000000000..1c6708b20 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import io.substrait.spark.expression.ToWindowFunction.fromSpark + +import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Expression, FrameType, Literal, OffsetWindowFunction, RangeFrame, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnspecifiedFrame, WindowExpression, WindowFrame, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.types.{IntegerType, LongType} + +import io.substrait.`type`.Type +import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg, WindowBound} +import io.substrait.expression.Expression.WindowBoundsType +import io.substrait.expression.WindowBound.{CURRENT_ROW, UNBOUNDED, WindowBoundVisitor} +import io.substrait.extension.SimpleExtension +import io.substrait.relation.ConsistentPartitionWindow.WindowRelFunctionInvocation + +import scala.collection.JavaConverters + +abstract class ToWindowFunction(functions: Seq[SimpleExtension.WindowFunctionVariant]) + extends FunctionConverter[SimpleExtension.WindowFunctionVariant, WindowRelFunctionInvocation]( + functions) { + + override def generateBinding( + sparkExp: Expression, + function: SimpleExtension.WindowFunctionVariant, + arguments: Seq[FunctionArg], + outputType: Type): WindowRelFunctionInvocation = { + + val (frameType, lower, upper) = sparkExp match { + case WindowExpression( + _, + WindowSpecDefinition(_, _, SpecifiedWindowFrame(frameType, lower, upper))) => + (fromSpark(frameType), fromSpark(lower), fromSpark(upper)) + case WindowExpression(_, WindowSpecDefinition(_, orderSpec, UnspecifiedFrame)) => + if (orderSpec.isEmpty) { + (WindowBoundsType.ROWS, UNBOUNDED, UNBOUNDED) + } else { + (WindowBoundsType.RANGE, UNBOUNDED, CURRENT_ROW) + } + + case _ => throw new UnsupportedOperationException(s"Unsupported window expression: $sparkExp") + } + + ExpressionCreator.windowRelFunction( + function, + outputType, + SExpression.AggregationPhase.INITIAL_TO_RESULT, // use defaults... 
+ SExpression.AggregationInvocation.ALL, // Spark doesn't define these + frameType, + lower, + upper, + JavaConverters.asJavaIterable(arguments) + ) + } + + def convert( + expression: WindowExpression, + operands: Seq[SExpression]): Option[WindowRelFunctionInvocation] = { + val cls = expression.windowFunction match { + case agg: AggregateExpression => agg.aggregateFunction.getClass + case other => other.getClass + } + + Option(signatures.get(cls)) + .flatMap(m => m.attemptMatch(expression, operands)) + } + + def apply( + expression: WindowExpression, + operands: Seq[SExpression]): WindowRelFunctionInvocation = { + convert(expression, operands).getOrElse(throw new UnsupportedOperationException( + s"Unable to find binding for call ${expression.windowFunction} -- $operands -- $expression")) + } +} + +object ToWindowFunction { + def fromSpark(frameType: FrameType): WindowBoundsType = frameType match { + case RowFrame => WindowBoundsType.ROWS + case RangeFrame => WindowBoundsType.RANGE + case other => throw new UnsupportedOperationException(s"Unsupported bounds type: $other.") + } + + def fromSpark(bound: Expression): WindowBound = bound match { + case UnboundedPreceding => WindowBound.UNBOUNDED + case UnboundedFollowing => WindowBound.UNBOUNDED + case CurrentRow => WindowBound.CURRENT_ROW + case e: Literal => + e.dataType match { + case IntegerType | LongType => + val offset = e.eval().asInstanceOf[Int] + if (offset < 0) WindowBound.Preceding.of(-offset) + else if (offset == 0) WindowBound.CURRENT_ROW + else WindowBound.Following.of(offset) + } + case _ => throw new UnsupportedOperationException(s"Unexpected bound: $bound") + } + + def toSparkFrame( + boundsType: WindowBoundsType, + lowerBound: WindowBound, + upperBound: WindowBound): WindowFrame = { + val frameType = boundsType match { + case WindowBoundsType.ROWS => RowFrame + case WindowBoundsType.RANGE => RangeFrame + case WindowBoundsType.UNSPECIFIED => return UnspecifiedFrame + } + SpecifiedWindowFrame( + frameType, + toSparkBound(lowerBound, isLower = true), + toSparkBound(upperBound, isLower = false)) + } + + private def toSparkBound(bound: WindowBound, isLower: Boolean): Expression = { + bound.accept(new WindowBoundVisitor[Expression, Exception] { + + override def visit(preceding: WindowBound.Preceding): Expression = + Literal(-preceding.offset().intValue()) + + override def visit(following: WindowBound.Following): Expression = + Literal(following.offset().intValue()) + + override def visit(currentRow: WindowBound.CurrentRow): Expression = CurrentRow + + override def visit(unbounded: WindowBound.Unbounded): Expression = + if (isLower) UnboundedPreceding else UnboundedFollowing + }) + } + + def apply(functions: Seq[SimpleExtension.WindowFunctionVariant]): ToWindowFunction = { + new ToWindowFunction(functions) { + override def getSigs: Seq[Sig] = + FunctionMappings.WINDOW_SIGS ++ FunctionMappings.AGGREGATE_SIGS + } + } + +} diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 7525ccc5f..daec2a5ed 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -119,6 +119,56 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } } + override def visit(window: relation.ConsistentPartitionWindow): LogicalPlan = { + val child = window.getInput.accept(this) + withChild(child) { + val partitions = 
window.getPartitionExpressions.asScala + .map(expr => expr.accept(expressionConverter)) + val sortOrders = window.getSorts.asScala.map(toSortOrder) + val windowExpressions = window.getWindowFunctions.asScala + .map( + func => { + val arguments = func.arguments().asScala.zipWithIndex.map { + case (arg, i) => + arg.accept(func.declaration(), i, expressionConverter) + } + val windowFunction = SparkExtension.toWindowFunction + .getSparkExpressionFromSubstraitFunc(func.declaration.key, func.outputType) + .map(sig => sig.makeCall(arguments)) + .map { + case win: WindowFunction => win + case agg: AggregateFunction => + AggregateExpression( + agg, + ToAggregateFunction.toSpark(func.aggregationPhase()), + ToAggregateFunction.toSpark(func.invocation()), + None) + } + .getOrElse({ + val msg = String.format( + "Unable to convert Window function %s(%s).", + func.declaration.name, + func.arguments.asScala + .map { + case ea: exp.EnumArg => ea.value.toString + case e: SExpression => e.getType.accept(new StringTypeVisitor) + case t: Type => t.accept(new StringTypeVisitor) + case a => throw new IllegalStateException("Unexpected value: " + a) + } + .mkString(", ") + ) + throw new IllegalArgumentException(msg) + }) + val frame = + ToWindowFunction.toSparkFrame(func.boundsType(), func.lowerBound(), func.upperBound()) + val spec = WindowSpecDefinition(partitions, sortOrders, frame) + WindowExpression(windowFunction, spec) + }) + .map(toNamedExpression) + Window(windowExpressions, partitions, sortOrders, child) + } + } + override def visit(join: relation.Join): LogicalPlan = { val left = join.getLeft.accept(this) val right = join.getRight.accept(this) @@ -162,6 +212,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } SortOrder(expression, direction, nullOrdering, Seq.empty) } + override def visit(fetch: relation.Fetch): LogicalPlan = { val child = fetch.getInput.accept(this) val limit = fetch.getCount.getAsLong.intValue() @@ -180,6 +231,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] Offset(toLiteral(offset), child) } } + override def visit(sort: relation.Sort): LogicalPlan = { val child = sort.getInput.accept(this) withChild(child) { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index c27c98a91..b93eaecbe 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NullType, StructType} import ToSubstraitType.toNamedStruct import io.substrait.{proto, relation} @@ -172,6 +172,37 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { .build() } + private def fromWindowCall( + expression: WindowExpression, + output: Seq[Attribute]): relation.ConsistentPartitionWindow.WindowRelFunctionInvocation = { + val children = expression.windowFunction match { + case agg: AggregateExpression => agg.aggregateFunction.children + case _: RankLike => Seq.empty + case other => other.children + } + val substraitExps = children.filter(_ != Literal(null, 
NullType)).map(toExpression(output)) + SparkExtension.toWindowFunction.apply(expression, substraitExps) + } + + override def visitWindow(window: Window): relation.Rel = { + val windowExpressions = window.windowExpressions.map { + case w: WindowExpression => fromWindowCall(w, window.child.output) + case a: Alias if a.child.isInstanceOf[WindowExpression] => + fromWindowCall(a.child.asInstanceOf[WindowExpression], window.child.output) + case other => + throw new UnsupportedOperationException(s"Unsupported window expression: $other") + }.asJava + + val partitionExpressions = window.partitionSpec.map(toExpression(window.child.output)).asJava + val sorts = window.orderSpec.map(toSortField(window.child.output)).asJava + relation.ConsistentPartitionWindow.builder + .input(visit(window.child)) + .addAllWindowFunctions(windowExpressions) + .addAllPartitionExpressions(partitionExpressions) + .addAllSorts(sorts) + .build() + } + private def asLong(e: Expression): Long = e match { case IntegerLiteral(limit) => limit case other => throw new UnsupportedOperationException(s"Unknown type: $other") diff --git a/spark/src/main/scala/io/substrait/utils/Util.scala b/spark/src/main/scala/io/substrait/utils/Util.scala index 165d59953..f7d373155 100644 --- a/spark/src/main/scala/io/substrait/utils/Util.scala +++ b/spark/src/main/scala/io/substrait/utils/Util.scala @@ -29,6 +29,7 @@ object Util { * Thomas Preissler */ def crossProduct[T](lists: Seq[Seq[T]]): Seq[Seq[T]] = { + if (lists.isEmpty) return lists /** list [a, b], element 1 => list + element => [a, b, 1] */ val appendElementToList: (Seq[T], T) => Seq[T] = diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index bf4c76ed9..5d3ff29aa 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -32,16 +32,16 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { } // spotless:off - val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", "q8", - "q11", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19", - "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", - "q30", "q31", "q32", "q33", "q37", "q38", - "q40", "q41", "q42", "q43", "q46", "q48", + val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q5", "q7", "q8", + "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19", + "q20", "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", + "q30", "q31", "q32", "q33", "q36", "q37", "q38", + "q40", "q41", "q42", "q43", "q44", "q46", "q48", "q49", "q50", "q52", "q54", "q55", "q56", "q58", "q59", - "q60", "q61", "q62", "q65", "q66", "q68", "q69", - "q71", "q73", "q76", "q79", - "q81", "q82", "q85", "q87", "q88", - "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q99") + "q60", "q61", "q62", "q65", "q66", "q67", "q68", "q69", + "q70", "q71", "q73", "q76", "q77", "q79", + "q80", "q81", "q82", "q85", "q86", "q87", "q88", + "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") // spotless:on tpcdsQueries.foreach { @@ -57,7 +57,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { } } - ignore("window") { + test("window") { val qry = s"""(SELECT | item_sk, | rank() diff --git a/spark/src/test/scala/io/substrait/spark/WindowPlan.scala b/spark/src/test/scala/io/substrait/spark/WindowPlan.scala new file mode 100644 index 000000000..1533e5b5b --- /dev/null +++ 
b/spark/src/test/scala/io/substrait/spark/WindowPlan.scala @@ -0,0 +1,79 @@ +package io.substrait.spark + +import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} + +import org.apache.spark.sql.TPCBase +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * These tests are based on the examples in the Spark documentation on Window functions. + * https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-window.html + */ +class WindowPlan extends TPCBase with SubstraitPlanTestBase { + override def beforeAll(): Unit = { + super.beforeAll() + sparkContext.setLogLevel("WARN") + } + + override protected def createTables(): Unit = { + spark.sql( + "CREATE TABLE employees (name STRING, dept STRING, salary INT, age INT) USING parquet;") + } + + override protected def dropTables(): Unit = { + spark.sessionState.catalog.dropTable(TableIdentifier("employees"), true, true) + } + + test("rank") { + val query = + """ + |SELECT name, dept, salary, RANK() OVER (PARTITION BY dept ORDER BY salary) AS rank FROM employees + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } + + test("cume_dist") { + val query = + """ + |SELECT name, dept, age, CUME_DIST() OVER (PARTITION BY dept ORDER BY age + | RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cume_dist FROM employees + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } + + test("aggregate") { + val query = + """ + |SELECT name, dept, salary, MIN(salary) OVER (PARTITION BY dept ORDER BY salary) AS min + | FROM employees + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } + + test("lag/lead") { + val query = + """ + |SELECT name, salary, + | LAG(salary) OVER (PARTITION BY dept ORDER BY salary) AS lag, + | LEAD(salary, 1, 0) OVER (PARTITION BY dept ORDER BY salary) AS lead + | FROM employees; + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } + + test("different partitions") { + val query = + """ + |SELECT name, salary, + | LAG(salary) OVER (PARTITION BY dept ORDER BY salary) AS lag, + | LEAD(age, 1, 0) OVER (PARTITION BY salary ORDER BY age) AS lead + | FROM employees; + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } +}
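
For reference, a minimal round-trip sketch of the Window support this patch adds. It is only a sketch under stated assumptions: it assumes the converters are driven through their visitor entry points (ToSubstraitRel.visit on the optimized Spark plan, and Rel.accept with a ToLogicalPlan visitor), and it assumes an employees table like the one created in WindowPlan.scala; none of these invocation details are prescribed by the patch itself.

import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}
import org.apache.spark.sql.SparkSession

// Assumption: a local SparkSession with an `employees` table as in WindowPlan.scala.
val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()

val plan = spark
  .sql("""SELECT name, dept, salary,
         |       RANK() OVER (PARTITION BY dept ORDER BY salary) AS rank
         |FROM employees""".stripMargin)
  .queryExecution
  .optimizedPlan

// Spark logical plan -> Substrait: visitWindow turns the Window operator into a
// ConsistentPartitionWindow rel (functions, partition expressions, sorts).
val substraitRel = new ToSubstraitRel().visit(plan)

// Substrait -> Spark logical plan: ConsistentPartitionWindow is rebuilt as a Window
// operator with WindowExpression/WindowSpecDefinition children.
val roundTripped = substraitRel.accept(new ToLogicalPlan(spark))

This mirrors what assertSqlSubstraitRelRoundTrip exercises in the new WindowPlan tests: the same query converted to a Substrait rel and back should produce an equivalent Spark logical plan.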