Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-955] Support window function lag (#1056)
Browse files Browse the repository at this point in the history
* Initial commit

* Implement lag based on window rank

* Set lag value for non-boundary case

* Refine the code

* Fix compile issue

* Fix issue caused by rebasing

* Disable string type

* Fix bugs

* Correct the setting for validity

* Add a native test

* Change test case

* Pass lag function options to native function

* Support null literal

* Sort according to orderSpec

* Consider orderSpec in ut

* Let non-literal offset or default value fall back

* Format the code

* Fix native ut issue

* Support StringType input and refactor the code
  • Loading branch information
PHILO-HE authored Oct 11, 2022
1 parent 3c0db9e commit b1dc522
Show file tree
Hide file tree
Showing 6 changed files with 541 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package com.intel.oap.execution
import java.util.concurrent.TimeUnit
import com.google.flatbuffers.FlatBufferBuilder
import com.intel.oap.GazellePluginConfig
import com.intel.oap.expression.{CodeGeneration, ConverterUtils}
import com.intel.oap.expression.{CodeGeneration, ColumnarLiteral, ConverterUtils}
import com.intel.oap.vectorized.{ArrowWritableColumnVector, CloseableColumnBatchIterator, ExpressionEvaluator}
import org.apache.arrow.gandiva.expression.TreeBuilder
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
Expand All @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, KnownFloatingPointNormalized, Descending, Expression, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, RowNumber, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, KnownFloatingPointNormalized, Descending, Expression, Lag, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, RowNumber, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
Expand Down Expand Up @@ -216,6 +216,32 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
case Some(false) => "row_number_asc"
case None => "row_number_asc"
}
case lag: Lag =>
if (!lag.children(1).isInstanceOf[Literal] ||
!lag.children(2).isInstanceOf[Literal]) {
throw new UnsupportedOperationException("Non-literal offset or default value" +
" is NOT supported for columnar lag function!")
}
val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) {
(desc, s) =>
val currentDesc = s.direction match {
case Ascending => false
case Descending => true
case _ => throw new IllegalStateException
}
if (desc.isEmpty) {
Some(currentDesc)
} else if (currentDesc == desc.get) {
Some(currentDesc)
} else {
throw new UnsupportedOperationException("lag: clashed rank order found")
}
}
desc match {
case Some(true) => "lag_desc"
case Some(false) => "lag_asc"
case None => "lag_asc"
}
case f => throw new UnsupportedOperationException("unsupported window function: " + f)
}
if (name.startsWith("row_number")) {
Expand Down Expand Up @@ -247,8 +273,27 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
List(TreeBuilder.makeField(
Field.nullable(attr.name,
CodeGeneration.getResultType(attr.dataType)))).toList.asJava,
NoneType.NONE_TYPE
NoneType.NONE_TYPE
)
case (n, f) if n.startsWith("lag") =>
TreeBuilder.makeFunction(n,
f.children.flatMap {
case a: AttributeReference =>
val attr = ConverterUtils.getAttrFromExpr(a)
Some(TreeBuilder.makeField(
Field.nullable(attr.name,
CodeGeneration.getResultType(attr.dataType))))
case lit: Literal =>
val literalNode = lit match {
case lit if lit.value == null =>
// Meaningless type for null. No need to care about it.
TreeBuilder.makeNull(ArrowType.Utf8.INSTANCE)
case lit =>
val (node, _) = new ColumnarLiteral(lit).doColumnarCodeGen(null)
node
}
Some(literalNode)
}.toList.asJava, NoneType.NONE_TYPE)
case (n, f) =>
TreeBuilder.makeFunction(n,
f.children
Expand All @@ -270,7 +315,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
}.toList.asJava,
NoneType.NONE_TYPE)
}
// TODO(yuan): using ConverterUtils.getAttrFromExpr
// TODO(yuan): using ConverterUtils.getAttrFromExpr
val groupingExpressions: Seq[AttributeReference] = partitionSpec.map{
case a: AttributeReference =>
ConverterUtils.getAttrFromExpr(a)
Expand All @@ -281,21 +326,43 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
case n: KnownFloatingPointNormalized =>
ConverterUtils.getAttrFromExpr(n.child)
case nomatch =>
throw new IllegalStateException()
throw new IllegalStateException("Not matched for getting partition expr!")
}.filter(_ != null)

val gPartitionSpec = TreeBuilder.makeFunction("partitionSpec",
groupingExpressions.map(e => TreeBuilder.makeField(
Field.nullable(e.name,
CodeGeneration.getResultType(e.dataType)))).toList.asJava,
NoneType.NONE_TYPE)

val orderExpressions: Seq[AttributeReference] = orderSpec.map(
od => od.child match {
case a: AttributeReference =>
ConverterUtils.getAttrFromExpr(a)
case c: Cast if c.child.isInstanceOf[AttributeReference] =>
ConverterUtils.getAttrFromExpr(c)
case _: Cast | _ : Literal =>
null
case n: KnownFloatingPointNormalized =>
ConverterUtils.getAttrFromExpr(n.child)
case nomatch =>
throw new IllegalStateException("Not matched for getting order expr!")
}
).filter(_ != null)

val gOrderSpec = TreeBuilder.makeFunction("orderSpec",
orderExpressions.map(e => TreeBuilder.makeField(
Field.nullable(e.name,
CodeGeneration.getResultType(e.dataType)))).toList.asJava,
NoneType.NONE_TYPE)

// Workaround:
// Gandiva doesn't support serializing Struct type so far. Use a fake Binary type instead.
val returnType = ArrowType.Binary.INSTANCE
val fieldType = new FieldType(false, returnType, null)
val resultField = new Field("window_res", fieldType,
windowFunctions.map {
case (row_number_func, f) if row_number_func.startsWith("row_number")=>
case (row_number_func, f) if row_number_func.startsWith("row_number") =>
// row_number will return int32 based indicies
new ArrowType.Int(32, true)
case (_, f) =>
Expand All @@ -305,7 +372,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
}.asJava)

val window = TreeBuilder.makeFunction("window",
(gWindowFunctions.toList ++ List(gPartitionSpec)).asJava, returnType)
(gWindowFunctions.toList ++ List(gPartitionSpec) ++ List(gOrderSpec)).asJava, returnType)

val evaluator = new ExpressionEvaluator()
val resultSchema = new Schema(resultField.getChildren)
Expand Down
42 changes: 34 additions & 8 deletions native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ arrow::Status ExprVisitor::MakeWindow(
child_func_name == "min" || child_func_name == "max" ||
child_func_name == "count" || child_func_name == "count_literal" ||
child_func_name == "rank_asc" || child_func_name == "rank_desc" ||
child_func_name == "row_number_desc" || child_func_name == "row_number_asc") {
child_func_name == "row_number_desc" || child_func_name == "row_number_asc" ||
child_func_name == "lag_desc" || child_func_name == "lag_asc") {
window_functions.push_back(child_function);
} else if (child_func_name == "partitionSpec") {
partition_spec = child_function;
Expand Down Expand Up @@ -455,16 +456,30 @@ arrow::Status ExprVisitor::MakeExprVisitorImpl(
std::vector<std::shared_ptr<arrow::Field>> ret_fields, ExprVisitor* p) {
std::vector<std::string> window_function_names;
std::vector<std::vector<gandiva::FieldPtr>> function_param_fields;
std::vector<std::shared_ptr<gandiva::LiteralNode>> lag_options;
for (auto window_function : window_functions) {
std::string window_function_name = window_function->descriptor()->name();
std::vector<gandiva::FieldPtr> function_param_fields_of_each;
for (std::shared_ptr<gandiva::Node> child : window_function->children()) {
std::shared_ptr<gandiva::FieldNode> field =
std::dynamic_pointer_cast<gandiva::FieldNode>(child);
if (field == nullptr) {
continue;
}
// Specially handling for lag function to get the offset & default value.
if (window_function_name.find("lag") != std::string::npos) {
auto field = std::dynamic_pointer_cast<gandiva::FieldNode>(
window_function->children().at(0));
function_param_fields_of_each.push_back(field->field());
auto offset_node = std::dynamic_pointer_cast<gandiva::LiteralNode>(
window_function->children().at(1));
lag_options.push_back(offset_node);
auto default_node = std::dynamic_pointer_cast<gandiva::LiteralNode>(
window_function->children().at(2));
lag_options.push_back(default_node);
} else {
for (std::shared_ptr<gandiva::Node> child : window_function->children()) {
std::shared_ptr<gandiva::FieldNode> field =
std::dynamic_pointer_cast<gandiva::FieldNode>(child);
if (field == nullptr) {
continue;
}
function_param_fields_of_each.push_back(field->field());
}
}
window_function_names.push_back(window_function_name);
function_param_fields.push_back(function_param_fields_of_each);
Expand All @@ -475,14 +490,25 @@ arrow::Status ExprVisitor::MakeExprVisitorImpl(
std::dynamic_pointer_cast<gandiva::FieldNode>(child);
partition_fields.push_back(field->field());
}
std::vector<gandiva::FieldPtr> order_fields;
// order_spec is not required for all window functions. It can be null.
if (order_spec != nullptr) {
for (std::shared_ptr<gandiva::Node> child : order_spec->children()) {
std::shared_ptr<gandiva::FieldNode> field =
std::dynamic_pointer_cast<gandiva::FieldNode>(child);
order_fields.push_back(field->field());
}
}

std::vector<std::shared_ptr<arrow::DataType>> return_types;
for (auto return_field : ret_fields) {
std::shared_ptr<arrow::DataType> type = return_field->type();
return_types.push_back(type);
}
// todo order_spec frame_spec
RETURN_NOT_OK(WindowVisitorImpl::Make(p, window_function_names, return_types,
function_param_fields, partition_fields, &impl_));
function_param_fields, partition_fields,
order_fields, lag_options, &impl_));
return arrow::Status();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,22 +105,29 @@ class WindowVisitorImpl : public ExprVisitorImpl {
WindowVisitorImpl(ExprVisitor* p, std::vector<std::string> window_function_names,
std::vector<std::shared_ptr<arrow::DataType>> return_types,
std::vector<std::vector<gandiva::FieldPtr>> function_param_fields,
std::vector<gandiva::FieldPtr> partition_fields)
std::vector<gandiva::FieldPtr> partition_fields,
std::vector<gandiva::FieldPtr> order_fields,
std::vector<std::shared_ptr<gandiva::LiteralNode>> lag_options)
: ExprVisitorImpl(p) {
this->window_function_names_ = window_function_names;
this->return_types_ = return_types,
this->function_param_fields_ = function_param_fields;
this->partition_fields_ = partition_fields;
this->order_fields_ = order_fields;
this->lag_options_ = lag_options;
}

static arrow::Status Make(
ExprVisitor* p, std::vector<std::string> window_function_names,
std::vector<std::shared_ptr<arrow::DataType>> return_types,
std::vector<std::vector<gandiva::FieldPtr>> function_param_fields,
std::vector<gandiva::FieldPtr> partition_fields,
std::vector<gandiva::FieldPtr> order_fields,
std::vector<std::shared_ptr<gandiva::LiteralNode>> lag_options,
std::shared_ptr<ExprVisitorImpl>* out) {
auto impl = std::make_shared<WindowVisitorImpl>(
p, window_function_names, return_types, function_param_fields, partition_fields);
p, window_function_names, return_types, function_param_fields, partition_fields,
order_fields, lag_options);
*out = impl;
return arrow::Status::OK();
}
Expand All @@ -145,6 +152,16 @@ class WindowVisitorImpl : public ExprVisitorImpl {

RETURN_NOT_OK(extra::EncodeArrayKernel::Make(&p_->ctx_, &partition_kernel_));

std::vector<std::shared_ptr<arrow::DataType>> order_type_list;
for (auto order_field : order_fields_) {
std::shared_ptr<arrow::Field> field;
int col_id;
RETURN_NOT_OK(
GetColumnIdAndFieldByName(p_->schema_, order_field->name(), &col_id, &field));
order_field_ids_.push_back(col_id);
order_type_list.push_back(field->type());
}

for (int func_id = 0; func_id < window_function_names_.size(); func_id++) {
std::string window_function_name = window_function_names_.at(func_id);
std::shared_ptr<arrow::DataType> return_type = return_types_.at(func_id);
Expand Down Expand Up @@ -185,6 +202,14 @@ class WindowVisitorImpl : public ExprVisitorImpl {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, false));
} else if (window_function_name == "lag_desc") {
RETURN_NOT_OK(extra::WindowLagKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, lag_options_,
&function_kernel, true, return_type, order_type_list));
} else if (window_function_name == "lag_asc") {
RETURN_NOT_OK(extra::WindowLagKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, lag_options_,
&function_kernel, false, return_type, order_type_list));
} else {
return arrow::Status::Invalid("window function not supported: " +
window_function_name);
Expand Down Expand Up @@ -266,6 +291,18 @@ class WindowVisitorImpl : public ExprVisitorImpl {
in3.push_back(val); // single column
}
in3.push_back(out2); // group_ids
// field for sorting.
for (auto col_id : order_field_ids_) {
if (col_id >= p_->in_record_batch_->num_columns()) {
return arrow::Status::Invalid(
"WindowVisitorImpl: Function parameter number overflows defined "
"column "
"count");
}
auto col = p_->in_record_batch_->column(col_id);
in3.push_back(col);
}

#ifdef DEBUG
std::cout << "[window kernel] Calling "
"function_kernels_.at(func_id)->Evaluate(in3) on batch... "
Expand Down Expand Up @@ -334,9 +371,12 @@ class WindowVisitorImpl : public ExprVisitorImpl {
std::vector<std::string> window_function_names_;
std::vector<std::shared_ptr<arrow::DataType>> return_types_;
std::vector<std::vector<gandiva::FieldPtr>> function_param_fields_;
std::vector<std::shared_ptr<gandiva::LiteralNode>> lag_options_;
std::vector<gandiva::FieldPtr> partition_fields_;
std::vector<gandiva::FieldPtr> order_fields_;
std::vector<std::vector<int>> function_param_field_ids_;
std::vector<int> partition_field_ids_;
std::vector<int> order_field_ids_;
std::shared_ptr<extra::KernalBase> concat_kernel_;
std::shared_ptr<extra::KernalBase> partition_kernel_;
std::vector<std::shared_ptr<extra::KernalBase>> function_kernels_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <arrow/array.h>
#include <arrow/status.h>
#include <arrow/type_fwd.h>
#include <arrow/type_traits.h>
#include <arrow/util/iterator.h>
#include <gandiva/node.h>
#include <gandiva/tree_expr_builder.h>
Expand Down Expand Up @@ -333,7 +334,7 @@ class WindowRankKernel : public KernalBase {
std::shared_ptr<ArrayItemIndexS> i,
std::shared_ptr<ArrayItemIndexS> j, bool* out);

private:
protected:
std::shared_ptr<WindowSortKernel::Impl> sorter_;
arrow::compute::ExecContext* ctx_ = nullptr;
std::vector<ArrayList> input_cache_;
Expand All @@ -342,6 +343,42 @@ class WindowRankKernel : public KernalBase {
bool is_row_number_;
};

class WindowLagKernel : public WindowRankKernel {
public:
WindowLagKernel(arrow::compute::ExecContext* ctx,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<WindowSortKernel::Impl> sorter, bool desc, int offset,
std::shared_ptr<gandiva::LiteralNode> default_node,
std::shared_ptr<arrow::DataType> return_type,
std::vector<std::shared_ptr<arrow::DataType>> order_type_list);

static arrow::Status Make(
arrow::compute::ExecContext* ctx, std::string function_name,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::vector<std::shared_ptr<gandiva::LiteralNode>> lag_options,
std::shared_ptr<KernalBase>* out, bool desc,
std::shared_ptr<arrow::DataType> return_type,
std::vector<std::shared_ptr<arrow::DataType>> order_type_list);

arrow::Status Finish(ArrayList* out) override;

template <typename VALUE_TYPE, typename CType, typename BuilderType, typename ArrayType,
typename OP>
arrow::Status HandleSortedPartition(
std::vector<ArrayList>& values,
std::vector<std::shared_ptr<arrow::Int32Array>>& group_ids, int32_t max_group_id,
std::vector<std::vector<std::shared_ptr<ArrayItemIndexS>>>& sorted_partitions,
ArrayList* out, OP op);

private:
// positive offset means lag to the above row from the current row with an offset.
// negative offset means lag to the below row from the current row with an offset.
int offset_;
std::shared_ptr<gandiva::LiteralNode> default_node_;
std::shared_ptr<arrow::DataType> return_type_;
std::vector<std::shared_ptr<arrow::DataType>> order_type_list_;
};

/*class UniqueArrayKernel : public KernalBase {
public:
static arrow::Status Make(arrow::compute::ExecContext* ctx,
Expand Down
Loading

0 comments on commit b1dc522

Please sign in to comment.