diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala index 1a670e73a..2793e2d40 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala @@ -103,7 +103,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], try { breakable { for (func <- validateWindowFunctions()) { - if (func._1 == "row_number") { + if (func._1.startsWith("row_number")) { allLiteral = false break } @@ -196,10 +196,29 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], case None => "rank_asc" } case rw: RowNumber => - "row_number" + val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) { + (desc, s) => + val currentDesc = s.direction match { + case Ascending => false + case Descending => true + case _ => throw new IllegalStateException + } + if (desc.isEmpty) { + Some(currentDesc) + } else if (currentDesc == desc.get) { + Some(currentDesc) + } else { + throw new UnsupportedOperationException("row_number: clashed rank order found") + } + } + desc match { + case Some(true) => "row_number_desc" + case Some(false) => "row_number_asc" + case None => "row_number_asc" + } case f => throw new UnsupportedOperationException("unsupported window function: " + f) } - if (name == "row_number") { + if (name.startsWith("row_number")) { (name, orderSpec.head.child) } else { (name, func) @@ -221,10 +240,10 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], } else { val prev1 = System.nanoTime() val gWindowFunctions = windowFunctions.map { - case ("row_number", spec) => + case (row_number_func, spec) => //TODO(): should get attr from orderSpec val attr = ConverterUtils.getAttrFromExpr(orderSpec.head.child, true) - TreeBuilder.makeFunction("row_number", + TreeBuilder.makeFunction(row_number_func, List(TreeBuilder.makeField( Field.nullable(attr.name, CodeGeneration.getResultType(attr.dataType)))).toList.asJava, @@ -262,8 +281,12 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val returnType = ArrowType.Binary.INSTANCE val fieldType = new FieldType(false, returnType, null) val resultField = new Field("window_res", fieldType, - windowFunctions.map { case (_, f) => - CodeGeneration.getResultType(f.dataType) + windowFunctions.map { + case (row_number_func, f) if row_number_func.startsWith("row_number")=> + // row_number will return int32 based indicies + new ArrowType.Int(32, true) + case (_, f) => + CodeGeneration.getResultType(f.dataType) }.zipWithIndex.map { case (t, i) => Field.nullable(s"window_res_" + i, t) }.asJava) diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc index a0b5e6aae..8a0525bc4 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc @@ -282,7 +282,7 @@ arrow::Status ExprVisitor::MakeWindow( child_func_name == "min" || child_func_name == "max" || child_func_name == "count" || child_func_name == "count_literal" || child_func_name == "rank_asc" || child_func_name == "rank_desc" || - child_func_name == "row_number") { + child_func_name == "row_number_desc" || child_func_name == "row_number_asc") { window_functions.push_back(child_function); } else if (child_func_name == "partitionSpec") { partition_spec = child_function; diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h index edd1a1f42..9ffa2a196 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h @@ -177,10 +177,14 @@ class WindowVisitorImpl : public ExprVisitorImpl { RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name, function_param_type_list, &function_kernel, true)); - } else if (window_function_name == "row_number") { + } else if (window_function_name == "row_number_desc") { RETURN_NOT_OK(extra::WindowRankKernel::Make( &p_->ctx_, window_function_name, function_param_type_list, &function_kernel, - true /*FIXME: force decending*/)); + true)); + } else if (window_function_name == "row_number_asc") { + RETURN_NOT_OK(extra::WindowRankKernel::Make( + &p_->ctx_, window_function_name, function_param_type_list, &function_kernel, + false)); } else { return arrow::Status::Invalid("window function not supported: " + window_function_name); diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc index 94acdda4c..028778a72 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc @@ -297,7 +297,7 @@ arrow::Status WindowRankKernel::Make( throw JniPendingException("Window Sort codegen failed"); } } - if (function_name == "row_number") { + if (function_name.rfind("row_number", 0) == 0) { *out = std::make_shared(ctx, type_list, sorter, desc, true); } else { *out = std::make_shared(ctx, type_list, sorter, desc);