[NSE-1120] Support sum window function with order by statement (#1122)
* Initial commit

* Use correct type for getting input array

* Refine the code

* Refactor the code

* Do some refactor for rank/row_number

* Refine the code layout

* Fix ut issues

* Fall back for decimal type

* Ignore a spark ut

* Reformat some scala code

* Reformat native code
PHILO-HE authored Oct 17, 2022
1 parent 75f685c commit e143acb
Showing 8 changed files with 619 additions and 332 deletions.
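For context, this is the query shape the commit adds native support for: a sum over a window with both PARTITION BY and ORDER BY, using the default frame. A minimal Spark sketch (hypothetical data; assumes a local SparkSession; with the plugin enabled, the running sum is computed by the new native kernel):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

object RunningSumExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("running-sum").getOrCreate()
    import spark.implicits._

    val df = Seq((0, 1), (0, 2), (1, 3), (1, 4)).toDF("key", "x")
    // With ORDER BY and no explicit frame, Spark uses
    // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -- the only
    // frame the new native sum kernel accepts.
    val w = Window.partitionBy($"key").orderBy($"x")
    df.select($"key", $"x", sum($"x").over(w).as("running_sum")).show()

    spark.stop()
  }
}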
ColumnarWindowExec.scala
@@ -18,6 +18,7 @@
package com.intel.oap.execution

import java.util.concurrent.TimeUnit

import com.google.flatbuffers.FlatBufferBuilder
import com.intel.oap.GazellePluginConfig
import com.intel.oap.expression.{CodeGeneration, ColumnarLiteral, ConverterUtils}
@@ -28,8 +29,8 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, CurrentRow, Descending, Expression, KnownFloatingPointNormalized, Lag, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, RowNumber, SortOrder, SpecifiedWindowFrame, UnboundedPreceding, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, KnownFloatingPointNormalized, Descending, Expression, Lag, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, RowNumber, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -132,6 +133,26 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
}
}

def checkSortFunctionFrame(windowSpec: WindowSpecDefinition): Unit = {
if (windowSpec.orderSpec.nonEmpty) {
// We only support the default frame when an order is specified, i.e.,
// from UnboundedPreceding (lower bound) to CurrentRow (upper bound).
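// e.g., "sum(x) OVER (PARTITION BY k ORDER BY t)" gets this frame by
// default (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW); only the
// bounds are matched below, so a ROWS frame with the same bounds also passes.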
windowSpec.frameSpecification match {
case s: SpecifiedWindowFrame =>
s.lower match {
case UnboundedPreceding =>
case _ => throw new UnsupportedOperationException("Only UnboundedPreceding" +
" is supported as lower bound!")
}
s.upper match {
case CurrentRow =>
case _ => throw new UnsupportedOperationException("Only CurrentRow is supported" +
" as upper bound!")
}
}
}
}

def checkRankSpec(windowSpec: WindowSpecDefinition): Unit = {
// leave it empty for now
}
@@ -155,8 +176,40 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
case (expr, func) =>
val name = func match {
case _: Sum =>
checkAggFunctionSpec(expr.windowSpec)
"sum"
// Allow "order by" for sum aggregation.
// checkAggFunctionSpec(expr.windowSpec)
// For an order by on a literal, e.g., "order by 2", the behavior is the
// same as with no order by.
if (orderSpec.isEmpty || orderSpec.head.child.isInstanceOf[Literal]) {
"sum"
} else {
// Only the default frame is supported in the order-by case.
checkSortFunctionFrame(expr.windowSpec)
// TODO: support decimal type.
if (expr.windowFunction.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException("Decimal type is not supported!")
}
val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) {
(desc, s) =>
val currentDesc = s.direction match {
case Ascending => false
case Descending => true
case _ => throw new IllegalStateException
}
if (desc.isEmpty) {
Some(currentDesc)
} else if (currentDesc == desc.get) {
Some(currentDesc)
} else {
throw new UnsupportedOperationException("sum: clashed rank order found")
}
}
desc match {
case Some(true) => "sum_desc"
case Some(false) => "sum_asc"
case None => "sum_asc"
}
}
case _: Average =>
checkAggFunctionSpec(expr.windowSpec)
"avg"
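The direction resolution in the sum branch above reads well in isolation. A self-contained sketch (simplified; a stand-in Direction type replaces Catalyst's SortDirection):

sealed trait Direction
case object Asc extends Direction
case object Desc extends Direction

// Fold all ORDER BY keys down to a single direction, rejecting a mix of
// ascending and descending keys, then pick the native kernel name.
def resolveSumKernel(directions: Seq[Direction]): String = {
  val desc = directions.foldLeft[Option[Boolean]](None) { (acc, d) =>
    val current = d == Desc
    acc match {
      case None => Some(current)
      case Some(prev) if prev == current => acc
      case _ =>
        throw new UnsupportedOperationException("sum: clashed sort order found")
    }
  }
  // None can only occur with no order keys; default to ascending.
  if (desc.contains(true)) "sum_desc" else "sum_asc"
}

// resolveSumKernel(Seq(Asc, Asc))  == "sum_asc"
// resolveSumKernel(Seq(Desc))      == "sum_desc"
// resolveSumKernel(Seq(Asc, Desc)) throws UnsupportedOperationException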
DataFrameWindowFunctionsSuite.scala
@@ -772,17 +772,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest
}
}

test("Window spill with more than the inMemoryThreshold and spillThreshold") {
val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value")

withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1",
SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
assertSpilled(sparkContext, "select") {
df.select($"key", sum("value").over(window)).collect()
}
}
}
// This is for testing vanilla Spark's behavior, so we ignore it.
// test("Window spill with more than the inMemoryThreshold and spillThreshold") {
// val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
// val window = Window.partitionBy($"key").orderBy($"value")
//
// withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1",
// SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
// assertSpilled(sparkContext, "select") {
// df.select($"key", sum("value").over(window)).collect()
// }
// }
// }

test("SPARK-21258: complex object in combination with spilling") {
// Make sure we trigger the spilling path.
SQLWindowFunctionSuite.scala
@@ -470,36 +470,37 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSparkSession {
Row(1, 3, null) :: Row(2, null, 4) :: Nil)
}

test("test with low buffer spill threshold") {
val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
nums.createOrReplaceTempView("nums")

val expected =
Row(1, 1, 1) ::
Row(0, 2, 3) ::
Row(1, 3, 6) ::
Row(0, 4, 10) ::
Row(1, 5, 15) ::
Row(0, 6, 21) ::
Row(1, 7, 28) ::
Row(0, 8, 36) ::
Row(1, 9, 45) ::
Row(0, 10, 55) :: Nil

val actual = sql(
"""
|SELECT y, x, sum(x) OVER w1 AS running_sum
|FROM nums
|WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW)
""".stripMargin)

withSQLConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1",
WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") {
assertSpilled(sparkContext, "test with low buffer spill threshold") {
checkAnswer(actual, expected)
}
}

spark.catalog.dropTempView("nums")
}
// This is for testing vanilla Spark's behavior, so we ignore it.
// test("test with low buffer spill threshold") {
// val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
// nums.createOrReplaceTempView("nums")
//
// val expected =
// Row(1, 1, 1) ::
// Row(0, 2, 3) ::
// Row(1, 3, 6) ::
// Row(0, 4, 10) ::
// Row(1, 5, 15) ::
// Row(0, 6, 21) ::
// Row(1, 7, 28) ::
// Row(0, 8, 36) ::
// Row(1, 9, 45) ::
// Row(0, 10, 55) :: Nil
//
// val actual = sql(
// """
// |SELECT y, x, sum(x) OVER w1 AS running_sum
// |FROM nums
// |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW)
// """.stripMargin)
//
// withSQLConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "1",
// WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "2") {
// assertSpilled(sparkContext, "test with low buffer spill threshold") {
// checkAnswer(actual, expected)
// }
// }
//
// spark.catalog.dropTempView("nums")
// }
}
expr_visitor.cc
@@ -283,7 +283,8 @@ arrow::Status ExprVisitor::MakeWindow(
child_func_name == "count" || child_func_name == "count_literal" ||
child_func_name == "rank_asc" || child_func_name == "rank_desc" ||
child_func_name == "row_number_desc" || child_func_name == "row_number_asc" ||
child_func_name == "lag_desc" || child_func_name == "lag_asc") {
child_func_name == "lag_desc" || child_func_name == "lag_asc" ||
child_func_name == "sum_asc" || child_func_name == "sum_desc") {
window_functions.push_back(child_function);
} else if (child_func_name == "partitionSpec") {
partition_spec = child_function;
34 changes: 22 additions & 12 deletions native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h
@@ -180,28 +180,30 @@ class WindowVisitorImpl : public ExprVisitorImpl {
}
function_param_field_ids_.push_back(function_param_field_ids_of_each);

// For window aggregation without an order by statement.
if (window_function_name == "sum" || window_function_name == "avg" ||
window_function_name == "min" || window_function_name == "max" ||
window_function_name == "count" || window_function_name == "count_literal") {
RETURN_NOT_OK(extra::WindowAggregateFunctionKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, return_type,
&function_kernel));
} else if (window_function_name == "rank_asc") {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, false));
RETURN_NOT_OK(extra::WindowRankKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, &function_kernel,
false, order_type_list));
} else if (window_function_name == "rank_desc") {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, true));
RETURN_NOT_OK(extra::WindowRankKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, &function_kernel,
true, order_type_list));
} else if (window_function_name == "row_number_desc") {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, true));
// For row_number, the order type list is the same as function_param_type_list.
RETURN_NOT_OK(extra::WindowRankKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, &function_kernel,
true, function_param_type_list));
} else if (window_function_name == "row_number_asc") {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, false));
RETURN_NOT_OK(extra::WindowRankKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, &function_kernel,
false, function_param_type_list));
} else if (window_function_name == "lag_desc") {
RETURN_NOT_OK(extra::WindowLagKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, lag_options_,
@@ -210,6 +212,14 @@ class WindowVisitorImpl : public ExprVisitorImpl {
RETURN_NOT_OK(extra::WindowLagKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, lag_options_,
&function_kernel, false, return_type, order_type_list));
} else if (window_function_name == "sum_desc") {
RETURN_NOT_OK(extra::WindowSumKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, &function_kernel,
true, return_type, order_type_list));
} else if (window_function_name == "sum_asc") {
RETURN_NOT_OK(extra::WindowSumKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, &function_kernel,
false, return_type, order_type_list));
} else {
return arrow::Status::Invalid("window function not supported: " +
window_function_name);
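Condensed, the dispatch above maps each resolved function name to a kernel, threading the sort direction and the ORDER BY key types through to the sort-based kernels. A Scala sketch with hypothetical case classes standing in for the C++ kernel constructors:

sealed trait Kernel
case class AggregateKernel(fn: String) extends Kernel
case class RankKernel(desc: Boolean, orderTypes: Seq[String], isRowNumber: Boolean = false) extends Kernel
case class LagKernel(desc: Boolean, orderTypes: Seq[String]) extends Kernel
case class SumKernel(desc: Boolean, orderTypes: Seq[String]) extends Kernel

def makeKernel(name: String, paramTypes: Seq[String], orderTypes: Seq[String]): Kernel =
  name match {
    // No ORDER BY: whole-partition aggregation.
    case "sum" | "avg" | "min" | "max" | "count" | "count_literal" =>
      AggregateKernel(name)
    case "rank_asc"        => RankKernel(desc = false, orderTypes)
    case "rank_desc"       => RankKernel(desc = true, orderTypes)
    // For row_number, the order keys are the function parameters themselves.
    case "row_number_asc"  => RankKernel(desc = false, paramTypes, isRowNumber = true)
    case "row_number_desc" => RankKernel(desc = true, paramTypes, isRowNumber = true)
    case "lag_asc"         => LagKernel(desc = false, orderTypes)
    case "lag_desc"        => LagKernel(desc = true, orderTypes)
    // New in this commit: running sum over the sorted partition.
    case "sum_asc"         => SumKernel(desc = false, orderTypes)
    case "sum_desc"        => SumKernel(desc = true, orderTypes)
    case other =>
      throw new UnsupportedOperationException(s"window function not supported: $other")
  }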
83 changes: 65 additions & 18 deletions native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h
@@ -312,38 +312,56 @@ class HashAggregateKernel : public KernalBase {
arrow::compute::ExecContext* ctx_ = nullptr;
};

class WindowRankKernel : public KernalBase {
// An abstract base class for window functions that require sorting.
class WindowSortBase : public KernalBase {
public:
WindowRankKernel(arrow::compute::ExecContext* ctx,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<WindowSortKernel::Impl> sorter, bool desc,
bool is_row_number = false);
static arrow::Status Make(arrow::compute::ExecContext* ctx, std::string function_name,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<KernalBase>* out, bool desc);
arrow::Status Evaluate(ArrayList& in) override;
arrow::Status Finish(ArrayList* out) override;

arrow::Status SortToIndicesPrepare(std::vector<ArrayList> values);
arrow::Status SortToIndicesFinish(
std::vector<std::shared_ptr<ArrayItemIndexS>> elements_to_sort,
std::vector<std::shared_ptr<ArrayItemIndexS>>* offsets);

template <typename ArrayType>
arrow::Status AreTheSameValue(const std::vector<ArrayList>& values, int column,
std::shared_ptr<ArrayItemIndexS> i,
std::shared_ptr<ArrayItemIndexS> j, bool* out);
// Prepares for Finish(), e.g., sorting the input for each group.
arrow::Status prepareFinish();

protected:
std::shared_ptr<WindowSortKernel::Impl> sorter_;
arrow::compute::ExecContext* ctx_ = nullptr;
std::vector<ArrayList> input_cache_;
std::vector<std::shared_ptr<arrow::DataType>> type_list_;
bool desc_;

std::vector<std::shared_ptr<arrow::DataType>> order_type_list_;

std::vector<ArrayList> values_; // The window function input.
std::vector<std::shared_ptr<arrow::Int32Array>> group_ids_;
int32_t max_group_id_ = 0;
std::vector<std::vector<std::shared_ptr<ArrayItemIndexS>>> sorted_partitions_;
};

class WindowRankKernel : public WindowSortBase {
public:
WindowRankKernel(arrow::compute::ExecContext* ctx,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<WindowSortKernel::Impl> sorter, bool desc,
std::vector<std::shared_ptr<arrow::DataType>> order_type_list,
bool is_row_number = false);
static arrow::Status Make(
arrow::compute::ExecContext* ctx, std::string function_name,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<KernalBase>* out, bool desc,
std::vector<std::shared_ptr<arrow::DataType>> order_type_list);
arrow::Status Finish(ArrayList* out) override;

template <typename ArrayType>
arrow::Status AreTheSameValue(const std::vector<ArrayList>& values, int column,
std::shared_ptr<ArrayItemIndexS> i,
std::shared_ptr<ArrayItemIndexS> j, bool* out);

protected:
bool is_row_number_;
};

class WindowLagKernel : public WindowRankKernel {
class WindowLagKernel : public WindowSortBase {
public:
WindowLagKernel(arrow::compute::ExecContext* ctx,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
@@ -370,13 +388,42 @@ class WindowLagKernel : public WindowSortBase {
std::vector<std::vector<std::shared_ptr<ArrayItemIndexS>>>& sorted_partitions,
ArrayList* out, OP op);

private:
protected:
// A positive offset lags to the row above the current row by that offset;
// a negative offset lags to the row below the current row.
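// e.g., lag(x, 1) reads the row one above in the sorted partition;
// lag(x, -1) reads the row one below.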
std::shared_ptr<arrow::DataType> return_type_;
int offset_;
std::shared_ptr<gandiva::LiteralNode> default_node_;
};

// For the sum window function when sorting is needed (must consider the window frame).
class WindowSumKernel : public WindowSortBase {
public:
WindowSumKernel(arrow::compute::ExecContext* ctx,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<WindowSortKernel::Impl> sorter, bool desc,
std::shared_ptr<arrow::DataType> return_type,
std::vector<std::shared_ptr<arrow::DataType>> order_type_list);

static arrow::Status Make(
arrow::compute::ExecContext* ctx, std::string function_name,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<KernalBase>* out, bool desc,
std::shared_ptr<arrow::DataType> return_type,
std::vector<std::shared_ptr<arrow::DataType>> order_type_list);

arrow::Status Finish(ArrayList* out) override;

template <typename VALUE_TYPE, typename CType, typename BuilderType, typename ArrayType,
typename OP>
arrow::Status HandleSortedPartition(
std::vector<ArrayList>& values,
std::vector<std::shared_ptr<arrow::Int32Array>>& group_ids, int32_t max_group_id,
std::vector<std::vector<std::shared_ptr<ArrayItemIndexS>>>& sorted_partitions,
ArrayList* out, OP op);

protected:
std::shared_ptr<arrow::DataType> return_type_;
std::vector<std::shared_ptr<arrow::DataType>> order_type_list_;
};

/*class UniqueArrayKernel : public KernalBase {
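The net effect of the header changes: the partition/sort machinery that WindowRankKernel used to own moves into the shared WindowSortBase, and the rank, lag, and new sum kernels all derive from it, differing only in what they emit per sorted row. A toy Scala model of that split (illustrative only; the real kernels operate on Arrow arrays and handle nulls and many more types):

// Rows are (partitionKey, orderKey, value) triples.
abstract class SortedWindowSketch(desc: Boolean) {
  // Shared step (WindowSortBase): group rows by partition key, then sort
  // each group by the order key in the requested direction.
  protected def sortedPartitions(
      rows: Seq[(Int, Int, Int)]): Map[Int, Seq[(Int, Int, Int)]] = {
    val ord = if (desc) Ordering.Int.reverse else Ordering.Int
    rows.groupBy(_._1).map { case (k, g) => k -> g.sortBy(_._2)(ord) }
  }
  // Function-specific step: what to emit for each sorted row.
  def finish(rows: Seq[(Int, Int, Int)]): Map[Int, Seq[Int]]
}

// row_number-style output: 1-based position within the sorted partition.
class RowNumberSketch(desc: Boolean) extends SortedWindowSketch(desc) {
  def finish(rows: Seq[(Int, Int, Int)]): Map[Int, Seq[Int]] =
    sortedPartitions(rows).map { case (k, g) => k -> g.indices.map(_ + 1) }
}

// sum-style output: running sum over UNBOUNDED PRECEDING..CURRENT ROW.
class RunningSumSketch(desc: Boolean) extends SortedWindowSketch(desc) {
  def finish(rows: Seq[(Int, Int, Int)]): Map[Int, Seq[Int]] =
    sortedPartitions(rows).map { case (k, g) =>
      k -> g.map(_._3).scanLeft(0)(_ + _).tail
    }
}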
(Diffs for the remaining 2 of the 8 changed files are not shown.)
