From e09a5e68a163740ad48432db0f669ae21588a866 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Sun, 18 Jul 2021 12:01:33 +0800 Subject: [PATCH 1/3] [NSE-383] Release SMJ input data immediately after being used Also, add switch option spark.oap.sql.columnar.sortmergejoin.lazyread to Spark config. --- .../oap/vectorized/ExpressionEvaluator.java | 7 + .../ExpressionEvaluatorJniWrapper.java | 11 + .../com/intel/oap/ColumnarPluginConfig.scala | 4 + .../ColumnarWholeStageCodegenExec.scala | 61 +++- .../com/intel/oap/tpc/ds/TPCDSSuite.scala | 7 +- .../scala/com/intel/oap/tpc/h/TPCHSuite.scala | 12 +- .../codegen/arrow_compute/code_generator.h | 7 + .../src/codegen/arrow_compute/expr_visitor.cc | 7 + .../src/codegen/arrow_compute/expr_visitor.h | 6 + .../codegen/arrow_compute/expr_visitor_impl.h | 12 +- .../arrow_compute/ext/code_generator_base.h | 6 + .../codegen/arrow_compute/ext/kernels_ext.cc | 73 +++-- .../codegen/arrow_compute/ext/kernels_ext.h | 8 +- .../cpp/src/codegen/code_generator.h | 1 + .../cpp/src/codegen/common/relation.cc | 38 +++ .../cpp/src/codegen/common/relation_column.h | 268 +++++++++++++++++- .../cpp/src/codegen/common/sort_relation.h | 222 +++++++++++++-- .../src/codegen/compute_ext/code_generator.h | 2 + .../cpp/src/codegen/gandiva/code_generator.h | 1 + native-sql-engine/cpp/src/jni/jni_wrapper.cc | 84 ++++++ .../cpp/src/tests/arrow_compute_test_wscg.cc | 44 ++- 21 files changed, 808 insertions(+), 73 deletions(-) diff --git a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ExpressionEvaluator.java b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ExpressionEvaluator.java index aa2458ff4..4dc707704 100644 --- a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ExpressionEvaluator.java +++ b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ExpressionEvaluator.java @@ -19,6 +19,7 @@ import com.intel.oap.ColumnarPluginConfig; import com.intel.oap.spark.sql.execution.datasources.v2.arrow.Spiller; +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator; import org.apache.arrow.memory.ArrowBuf; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -143,6 +144,12 @@ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch) throws RuntimeE return evaluate(recordBatch, null); } + public void evaluate(NativeSerializedRecordBatchIterator batchItr) + throws RuntimeException, IOException { + jniWrapper.nativeEvaluateWithIterator(nativeHandler, + batchItr); + } + /** * Evaluate input data using builded native function, and output as recordBatch. */ diff --git a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ExpressionEvaluatorJniWrapper.java b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ExpressionEvaluatorJniWrapper.java index 0c153f7a9..066bff359 100644 --- a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ExpressionEvaluatorJniWrapper.java +++ b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/ExpressionEvaluatorJniWrapper.java @@ -17,6 +17,8 @@ package com.intel.oap.vectorized; +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.spark.memory.MemoryConsumer; import java.io.IOException; @@ -147,6 +149,15 @@ native ArrowRecordBatchBuilder[] nativeEvaluate(long nativeHandler, int numRows, native ArrowRecordBatchBuilder[] nativeEvaluate2(long nativeHandler, byte[] bytes) throws RuntimeException; + /** + * Evaluate the expressions represented by the nativeHandler on a record batch + * iterator. Throws an exception in case of errors + * + * @param nativeHandler a iterator instance carrying input record batches + */ + native void nativeEvaluateWithIterator(long nativeHandler, + NativeSerializedRecordBatchIterator batchItr) throws RuntimeException; + /** * Get native kernel signature by the nativeHandler. * diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala index cc4c16ed9..5cc054caa 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala @@ -30,6 +30,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { def getCpu(): Boolean = { val source = scala.io.Source.fromFile("/proc/cpuinfo") val lines = try source.mkString finally source.close() + return true //TODO(): check CPU flags to enable/disable AVX512 if (lines.contains("GenuineIntel")) { return true @@ -79,6 +80,9 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val enableColumnarSortMergeJoin: Boolean = conf.getConfString("spark.oap.sql.columnar.sortmergejoin", "true").toBoolean && enableCpu + val enableColumnarSortMergeJoinLazyRead: Boolean = + conf.getConfString("spark.oap.sql.columnar.sortmergejoin.lazyread", "false").toBoolean + // enable or disable columnar union val enableColumnarUnion: Boolean = conf.getConfString("spark.oap.sql.columnar.union", "true").toBoolean && enableCpu diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWholeStageCodegenExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWholeStageCodegenExec.scala index 1ecdf50c2..00a9bc0c8 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWholeStageCodegenExec.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWholeStageCodegenExec.scala @@ -17,12 +17,15 @@ package com.intel.oap.execution +import java.util.concurrent.TimeUnit.NANOSECONDS + import com.google.common.collect.Lists import com.intel.oap.ColumnarPluginConfig import com.intel.oap.expression._ import com.intel.oap.vectorized.{BatchIterator, ExpressionEvaluator, _} import org.apache.arrow.gandiva.expression._ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} @@ -33,10 +36,13 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils} - import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator +import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch + case class ColumnarCodegenContext(inputSchema: Schema, outputSchema: Schema, root: TreeNode) {} trait ColumnarCodegenSupport extends SparkPlan { @@ -73,6 +79,7 @@ case class ColumnarWholeStageCodegenExec(child: SparkPlan)(val codegenStageId: I val sparkConf = sparkContext.getConf val numaBindingInfo = ColumnarPluginConfig.getConf.numaBindingInfo + val enableColumnarSortMergeJoinLazyRead = ColumnarPluginConfig.getConf.enableColumnarSortMergeJoinLazyRead override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), @@ -290,6 +297,7 @@ case class ColumnarWholeStageCodegenExec(child: SparkPlan)(val codegenStageId: I var idx = 0 var curRDD = inputRDDs()(0) while (idx < buildPlans.length) { + val curPlan = buildPlans(idx)._1 val parentPlan = buildPlans(idx)._2 @@ -408,15 +416,46 @@ case class ColumnarWholeStageCodegenExec(child: SparkPlan)(val codegenStageId: I Lists.newArrayList(expression), outputSchema, true) - while (depIter.hasNext) { - val dep_cb = depIter.next() - if (dep_cb.numRows > 0) { - (0 until dep_cb.numCols).toList.foreach(i => - dep_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain()) - buildRelationBatchHolder += dep_cb - val dep_rb = ConverterUtils.createArrowRecordBatch(dep_cb) - cachedRelationKernel.evaluate(dep_rb) - ConverterUtils.releaseArrowRecordBatch(dep_rb) + + if (enableColumnarSortMergeJoinLazyRead) { + // Used as ABI to prevent from serializing buffer data + val serializedItr: NativeSerializedRecordBatchIterator = { + new NativeSerializedRecordBatchIterator { + + override def hasNext: Boolean = { + depIter.hasNext + } + + override def next(): Array[Byte] = { + val dep_cb = depIter.next() + if (dep_cb.numRows > 0) { + val dep_rb = ConverterUtils.createArrowRecordBatch(dep_cb) + serialize(dep_rb) + } else { + throw new IllegalStateException() + } + } + + private def serialize(batch: ArrowRecordBatch) = { + UnsafeRecordBatchSerializer.serializeUnsafe(batch) + } + + override def close(): Unit = { + } + } + } + cachedRelationKernel.evaluate(serializedItr) + } else { + while (depIter.hasNext) { + val dep_cb = depIter.next() + if (dep_cb.numRows > 0) { + (0 until dep_cb.numCols).toList.foreach(i => + dep_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain()) + buildRelationBatchHolder += dep_cb + val dep_rb = ConverterUtils.createArrowRecordBatch(dep_cb) + cachedRelationKernel.evaluate(dep_rb) + ConverterUtils.releaseArrowRecordBatch(dep_rb) + } } } dependentKernels += cachedRelationKernel @@ -570,7 +609,7 @@ case class ColumnarWholeStageCodegenExec(child: SparkPlan)(val codegenStageId: I def close = { closed = true pipelineTime += (eval_elapse + build_elapse) / 1000000 - buildRelationBatchHolder.foreach(_.close) + buildRelationBatchHolder.foreach(_.close) // fixing: ref cnt goes nagative dependentKernels.foreach(_.close) dependentKernelIterators.foreach(_.close) nativeKernel.close diff --git a/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala b/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala index 1dd9d23bb..5766f1f3e 100644 --- a/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala +++ b/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala @@ -52,6 +52,7 @@ class TPCDSSuite extends QueryTest with SharedSparkSession { .set("spark.unsafe.exceptionOnMemoryLeak", "false") .set("spark.network.io.preferDirectBufs", "false") .set("spark.sql.sources.useV1SourceList", "arrow,parquet") + .set("spark.sql.autoBroadcastJoinThreshold", "-1") return conf } @@ -59,7 +60,7 @@ class TPCDSSuite extends QueryTest with SharedSparkSession { override def beforeAll(): Unit = { super.beforeAll() LogManager.getRootLogger.setLevel(Level.WARN) - val tGen = new TPCDSTableGen(spark, 0.01D, TPCDS_WRITE_PATH) + val tGen = new TPCDSTableGen(spark, 0.1D, TPCDS_WRITE_PATH) tGen.gen() tGen.createTables() runner = new TPCRunner(spark, TPCDS_QUERIES_RESOURCE) @@ -91,6 +92,10 @@ class TPCDSSuite extends QueryTest with SharedSparkSession { runner.runTPCQuery("q67", 1, true) } + test("smj query") { + runner.runTPCQuery("q1", 1, true) + } + test("window function with non-decimal input") { val df = spark.sql("SELECT i_item_sk, i_class_id, SUM(i_category_id)" + " OVER (PARTITION BY i_class_id) FROM item LIMIT 1000") diff --git a/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/h/TPCHSuite.scala b/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/h/TPCHSuite.scala index 928743621..72b96179d 100644 --- a/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/h/TPCHSuite.scala +++ b/native-sql-engine/core/src/test/scala/com/intel/oap/tpc/h/TPCHSuite.scala @@ -59,7 +59,7 @@ class TPCHSuite extends QueryTest with SharedSparkSession { .set("spark.executor.heartbeatInterval", "3600000") .set("spark.network.timeout", "3601s") .set("spark.oap.sql.columnar.preferColumnar", "true") - .set("spark.oap.sql.columnar.sortmergejoin", "true") + .set("spark.oap.sql.columnar.sortmergejoin", "true") .set("spark.sql.columnar.codegen.hashAggregate", "false") .set("spark.sql.columnar.sort", "true") .set("spark.sql.columnar.window", "true") @@ -96,13 +96,21 @@ class TPCHSuite extends QueryTest with SharedSparkSession { } } - ignore("q12 SMJ failure") { + test("q12 SMJ") { withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1"), ("spark.oap.sql.columnar.sortmergejoin", "true")) { runner.runTPCQuery("q12", 1, true) } } + test("q12 SMJ lazy") { + withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1"), + ("spark.oap.sql.columnar.sortmergejoin", "true"), + ("spark.oap.sql.columnar.sortmergejoin.lazyread", "true")) { + runner.runTPCQuery("q12", 1, true) + } + } + private def runMemoryUsageTest(exclusions: Array[String] = Array[String](), comment: String = ""): Unit = { val enableTPCHTests = Option(System.getenv("ENABLE_TPCH_TESTS")) if (!enableTPCHTests.exists(_.toBoolean)) { diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/code_generator.h b/native-sql-engine/cpp/src/codegen/arrow_compute/code_generator.h index a66325633..c33b08efc 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/code_generator.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/code_generator.h @@ -153,6 +153,13 @@ class ArrowComputeCodeGenerator : public CodeGenerator { return status; } + arrow::Status evaluate(arrow::RecordBatchIterator in) override { + for (auto visitor : visitor_list_) { + TIME_MICRO_OR_RAISE(eval_elapse_time_, visitor->Eval(std::move(in))); + } + return arrow::Status::OK(); + }; + arrow::Status evaluate(const std::shared_ptr& selection_in, const std::shared_ptr& in, std::vector>* out) { diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc index ad9b9ea2c..e57642f24 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc @@ -523,6 +523,13 @@ arrow::Status ExprVisitor::Eval(std::shared_ptr& in) { return arrow::Status::OK(); } +arrow::Status ExprVisitor::Eval(arrow::RecordBatchIterator in) { + input_type_ = ArrowComputeInputType::Iterator; + in_iterator_ = std::move(in); + RETURN_NOT_OK(Eval()); + return arrow::Status::OK(); +} + arrow::Status ExprVisitor::Eval() { if (return_type_ != ArrowComputeResultType::None) { #ifdef DEBUG_LEVEL_2 diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.h b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.h index 4fb1c9be5..9a48fe1a9 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -40,6 +41,7 @@ class ExprVisitorImpl; using ExprVisitorMap = std::unordered_map>; using ArrayList = std::vector>; +enum class ArrowComputeInputType { Legacy, Iterator }; enum class ArrowComputeResultType { Array, Batch, BatchList, BatchIterator, None }; enum class BuilderVisitorNodeType { FunctionNode, FieldNode }; @@ -153,6 +155,8 @@ class ExprVisitor : public std::enable_shared_from_this { arrow::Status Eval(const std::shared_ptr& selection_in, const std::shared_ptr& in); arrow::Status Eval(std::shared_ptr& in); + arrow::Status Eval(const std::shared_ptr& in); + arrow::Status Eval(arrow::RecordBatchIterator in); arrow::Status Eval(); std::string GetSignature() { return signature_; } arrow::Status SetMember(const std::shared_ptr& ms); @@ -191,6 +195,8 @@ class ExprVisitor : public std::enable_shared_from_this { std::shared_ptr in_selection_array_; std::shared_ptr in_record_batch_; std::vector> in_record_batch_holder_; + ArrowComputeInputType input_type_ = ArrowComputeInputType::Legacy; + arrow::RecordBatchIterator in_iterator_; std::vector> ret_fields_; // For dual input kernels like probe diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h index 704228dd5..4f84bc6cd 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h @@ -602,11 +602,15 @@ class CachedRelationVisitorImpl : public ExprVisitorImpl { arrow::Status Eval() override { switch (p_->dependency_result_type_) { case ArrowComputeResultType::None: { - std::vector> col_list; - for (auto col : p_->in_record_batch_->columns()) { - col_list.push_back(col); + if (p_->input_type_ == ArrowComputeInputType::Iterator) { + RETURN_NOT_OK(kernel_->Evaluate(std::move(p_->in_iterator_))); + } else { + std::vector> col_list; + for (auto col : p_->in_record_batch_->columns()) { + col_list.push_back(col); + } + RETURN_NOT_OK(kernel_->Evaluate(col_list)); } - RETURN_NOT_OK(kernel_->Evaluate(col_list)); } break; default: return arrow::Status::NotImplemented( diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/code_generator_base.h b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/code_generator_base.h index cd42968fe..4fb34537b 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/code_generator_base.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/code_generator_base.h @@ -16,6 +16,8 @@ */ #pragma once +#include + #include "codegen/common/result_iterator.h" #include "precompile/array.h" @@ -37,6 +39,10 @@ class CodeGenBase { return arrow::Status::NotImplemented( "CodeGenBase Evaluate is an abstract interface."); } + virtual arrow::Status Evaluate(arrow::RecordBatchIterator in) { + return arrow::Status::NotImplemented( + "CodeGenBase Evaluate is an abstract interface."); + } virtual arrow::Status Evaluate(const ArrayList& in, const ArrayList& projected_batch) { return arrow::Status::NotImplemented( "CodeGenBase Evaluate is an abstract interface."); diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc index 0352b7bf2..5e3eb53a5 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc @@ -351,29 +351,57 @@ class CachedRelationKernel::Impl { return arrow::Status::OK(); } + arrow::Status Evaluate(arrow::RecordBatchIterator in) { + in_ = std::move(in); + is_lazy_input_ = true; + return arrow::Status::OK(); + } + arrow::Status MakeResultIterator(std::shared_ptr schema, std::shared_ptr>* out) { - std::vector> sort_relation_list; - int idx = 0; - for (auto field : result_schema_->fields()) { - std::shared_ptr col_out; - RETURN_NOT_OK(MakeRelationColumn(field->type()->id(), &col_out)); - if (cached_.size() == col_num_) { - for (auto arr : cached_[idx]) { - RETURN_NOT_OK(col_out->AppendColumn(arr)); + if (is_lazy_input_) { + std::vector> sort_relation_list; + std::shared_ptr lazy_in = + std::make_shared(std::move(in_)); + int idx = 0; + for (auto field : result_schema_->fields()) { + std::shared_ptr col_out; + RETURN_NOT_OK(MakeLazyLoadRelationColumn(field->type()->id(), &col_out)); + RETURN_NOT_OK(col_out->FromLazyBatchIterator(lazy_in, idx)); + sort_relation_list.push_back(col_out); + idx++; + } + std::vector> key_relation_list; + for (auto key_id : key_index_list_) { + key_relation_list.push_back(sort_relation_list[key_id]); + } + auto sort_relation = + SortRelation::CreateLazy(ctx_, lazy_in, key_relation_list, sort_relation_list); + *out = std::make_shared(sort_relation); + return arrow::Status::OK(); + } else { + std::vector> sort_relation_list; + int idx = 0; + for (auto field : result_schema_->fields()) { + std::shared_ptr col_out; + RETURN_NOT_OK(MakeRelationColumn(field->type()->id(), &col_out)); + if (cached_.size() == col_num_) { + for (auto arr : cached_[idx]) { + RETURN_NOT_OK(col_out->AppendColumn(arr)); + } } + sort_relation_list.push_back(col_out); + idx++; } - sort_relation_list.push_back(col_out); - idx++; - } - std::vector> key_relation_list; - for (auto key_id : key_index_list_) { - key_relation_list.push_back(sort_relation_list[key_id]); + std::vector> key_relation_list; + for (auto key_id : key_index_list_) { + key_relation_list.push_back(sort_relation_list[key_id]); + } + auto sort_relation = SortRelation::CreateLegacy( + ctx_, key_relation_list, sort_relation_list, items_total_, length_list_); + *out = std::make_shared(sort_relation); + return arrow::Status::OK(); } - auto sort_relation = std::make_shared( - ctx_, items_total_, length_list_, key_relation_list, sort_relation_list); - *out = std::make_shared(sort_relation); - return arrow::Status::OK(); } private: @@ -385,7 +413,12 @@ class CachedRelationKernel::Impl { std::vector> key_field_list_; std::shared_ptr result_schema_; + arrow::RecordBatchIterator in_; std::vector key_index_list_; + + // required by legacy method + bool is_lazy_input_ = false; + std::vector length_list_; std::vector cached_; uint64_t items_total_ = 0; @@ -424,6 +457,10 @@ arrow::Status CachedRelationKernel::Evaluate(ArrayList& in) { return impl_->Evaluate(in); } +arrow::Status CachedRelationKernel::Evaluate(arrow::RecordBatchIterator in) { + return impl_->Evaluate(std::move(in)); +} + arrow::Status CachedRelationKernel::MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) { diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h index 981a6a900..90931e253 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -51,6 +52,10 @@ class KernalBase { return arrow::Status::NotImplemented("Evaluate is abstract interface for ", kernel_name_, ", input is arrayList."); } + virtual arrow::Status Evaluate(arrow::RecordBatchIterator in) { + return arrow::Status::NotImplemented("Evaluate is abstract interface for ", + kernel_name_, ", input is iterator."); + } virtual arrow::Status Evaluate(const ArrayList& in, ArrayList* out) { return arrow::Status::NotImplemented("Evaluate is abstract interface for ", kernel_name_, @@ -232,7 +237,8 @@ class CachedRelationKernel : public KernalBase { std::shared_ptr result_schema, std::vector> key_field_list, int result_type); - arrow::Status Evaluate(ArrayList& in) override; + arrow::Status Evaluate(const ArrayList& in) override; + arrow::Status Evaluate(arrow::RecordBatchIterator in) override; arrow::Status MakeResultIterator( std::shared_ptr schema, std::shared_ptr>* out) override; diff --git a/native-sql-engine/cpp/src/codegen/code_generator.h b/native-sql-engine/cpp/src/codegen/code_generator.h index bfa49edf7..8d85d2686 100644 --- a/native-sql-engine/cpp/src/codegen/code_generator.h +++ b/native-sql-engine/cpp/src/codegen/code_generator.h @@ -35,6 +35,7 @@ class CodeGenerator { virtual arrow::Status evaluate( std::shared_ptr& in, std::vector>* out) = 0; + virtual arrow::Status evaluate(arrow::RecordBatchIterator in) = 0; virtual arrow::Status finish(std::vector>* out) = 0; virtual std::string GetSignature() { return ""; }; virtual arrow::Status finish(std::shared_ptr* out) { diff --git a/native-sql-engine/cpp/src/codegen/common/relation.cc b/native-sql-engine/cpp/src/codegen/common/relation.cc index 2cf716e79..90be2639b 100644 --- a/native-sql-engine/cpp/src/codegen/common/relation.cc +++ b/native-sql-engine/cpp/src/codegen/common/relation.cc @@ -95,6 +95,44 @@ arrow::Status MakeRelationColumn(uint32_t data_type_id, } #undef PROCESS_SUPPORTED_TYPES +/////////////////////////////////////////////////////////////////////////////////// +#define PROCESS_SUPPORTED_TYPES(PROCESS) \ + PROCESS(arrow::BooleanType) \ + PROCESS(arrow::UInt8Type) \ + PROCESS(arrow::Int8Type) \ + PROCESS(arrow::UInt16Type) \ + PROCESS(arrow::Int16Type) \ + PROCESS(arrow::UInt32Type) \ + PROCESS(arrow::Int32Type) \ + PROCESS(arrow::UInt64Type) \ + PROCESS(arrow::Int64Type) \ + PROCESS(arrow::FloatType) \ + PROCESS(arrow::DoubleType) \ + PROCESS(arrow::Date32Type) \ + PROCESS(arrow::Date64Type) \ + PROCESS(arrow::TimestampType) \ + PROCESS(arrow::Decimal128Type) \ + PROCESS(arrow::StringType) +arrow::Status MakeLazyLoadRelationColumn(uint32_t data_type_id, + std::shared_ptr* out) { + switch (data_type_id) { +#define PROCESS(InType) \ + case TypeTraits::type_id: { \ + auto typed_column = std::make_shared>(); \ + *out = std::dynamic_pointer_cast(typed_column); \ + } break; + PROCESS_SUPPORTED_TYPES(PROCESS) +#undef PROCESS + default: { + return arrow::Status::NotImplemented("MakeRelationColumn doesn't support type ", + data_type_id); + } break; + } + + return arrow::Status::OK(); +} +#undef PROCESS_SUPPORTED_TYPES + ///////////////////////////////////////////////////////////////////////// #define PROCESS_SUPPORTED_TYPES(PROCESS) \ diff --git a/native-sql-engine/cpp/src/codegen/common/relation_column.h b/native-sql-engine/cpp/src/codegen/common/relation_column.h index 4317178dc..a2fde5a5b 100644 --- a/native-sql-engine/cpp/src/codegen/common/relation_column.h +++ b/native-sql-engine/cpp/src/codegen/common/relation_column.h @@ -18,9 +18,12 @@ #pragma once #include +#include #include #include +#include +#include "iostream" #include "precompile/type_traits.h" using sparkcolumnarplugin::precompile::enable_if_number_or_decimal; @@ -28,6 +31,64 @@ using sparkcolumnarplugin::precompile::enable_if_string_like; using sparkcolumnarplugin::precompile::StringArray; using sparkcolumnarplugin::precompile::TypeTraits; +class LazyBatchIterator { + public: + LazyBatchIterator(arrow::RecordBatchIterator in) { this->in_ = std::move(in); } + + bool IsBatchReleased(int32_t batch_id) { + if (batch_id >= ref_cnts_.size()) { + return false; + } + if (ref_cnts_[batch_id] == 0) { + return true; + } + return false; + } + + std::shared_ptr GetBatch(int32_t batch_id) { + if (!AdvanceTo(batch_id)) { + return nullptr; + } + return cache_[batch_id]; + } + + bool AdvanceTo(int32_t batch_id) { + for (; current_batch_id_ <= batch_id; current_batch_id_++) { + std::shared_ptr next = in_.Next().ValueOrDie(); + if (next == nullptr) { + return false; + } + cache_.push_back(next); + ref_cnts_.push_back(0); + row_cnts_.push_back(next->num_rows()); + } + return true; + } + + int64_t GetNumRowsOfBatch(int32_t batch_id) { + if (!AdvanceTo(batch_id)) { + return -1L; + } + return row_cnts_[batch_id]; + } + + void RetainBatch(int32_t batch_id) { ref_cnts_[batch_id] = ref_cnts_[batch_id] + 1; } + + void ReleaseBatch(int32_t batch_id) { + ref_cnts_[batch_id] = ref_cnts_[batch_id] - 1; + if (ref_cnts_[batch_id] <= 0) { + cache_[batch_id] = nullptr; + } + } + + private: + arrow::RecordBatchIterator in_; + std::vector> cache_; + std::vector ref_cnts_; + std::vector row_cnts_; + int32_t current_batch_id_ = 0; +}; + class RelationColumn { public: virtual bool IsNull(int array_id, int id) = 0; @@ -35,6 +96,15 @@ class RelationColumn { virtual arrow::Status AppendColumn(std::shared_ptr in) { return arrow::Status::NotImplemented("RelationColumn AppendColumn is abstract."); }; + virtual arrow::Status FromLazyBatchIterator(std::shared_ptr in, + int field_id) { + return arrow::Status::NotImplemented("RelationColumn AppendColumn is abstract."); + }; + virtual arrow::Status AdvanceTo(int array_id) { + return arrow::Status::NotImplemented("RelationColumn Advance is abstract."); + }; + virtual int32_t Advance(int32_t array_offset) { return -1; }; + virtual arrow::Status ReleaseArray(int array_id) = 0; virtual arrow::Status GetArrayVector(std::vector>* out) { return arrow::Status::NotImplemented("RelationColumn GetArrayVector is abstract."); } @@ -49,7 +119,7 @@ class TypedRelationColumn> : public RelationColumn { public: using T = typename TypeTraits::CType; - TypedRelationColumn() {} + TypedRelationColumn() = default; bool IsNull(int array_id, int id) override { return (!has_null_) ? false : array_vector_[array_id]->IsNull(id); } @@ -67,13 +137,19 @@ class TypedRelationColumn> array_vector_.push_back(typed_in); return arrow::Status::OK(); } + arrow::Status ReleaseArray(int array_id) { + array_vector_[array_id] = nullptr; // fixme using reset()? + return arrow::Status::OK(); + } arrow::Status GetArrayVector(std::vector>* out) override { for (auto arr : array_vector_) { (*out).push_back(arr->cache_); } return arrow::Status::OK(); } - T GetValue(int array_id, int id) { return array_vector_[array_id]->GetView(id); } + virtual T GetValue(int array_id, int id) { + return array_vector_[array_id]->GetView(id); + } bool HasNull() { return has_null_; } private: @@ -104,13 +180,17 @@ class TypedRelationColumn> array_vector_.push_back(typed_in); return arrow::Status::OK(); } + arrow::Status ReleaseArray(int array_id) { + array_vector_[array_id] = nullptr; + return arrow::Status::OK(); + } arrow::Status GetArrayVector(std::vector>* out) override { for (auto arr : array_vector_) { (*out).push_back(arr->cache_); } return arrow::Status::OK(); } - std::string GetValue(int array_id, int id) { + virtual std::string GetValue(int array_id, int id) { return array_vector_[array_id]->GetString(id); } bool HasNull() { return has_null_; } @@ -122,3 +202,185 @@ class TypedRelationColumn> arrow::Status MakeRelationColumn(uint32_t data_type_id, std::shared_ptr* out); + +template +class TypedLazyLoadRelationColumn : public TypedRelationColumn {}; + +template +class TypedLazyLoadRelationColumn> + : public TypedRelationColumn { + public: + using T = typename TypeTraits::CType; + TypedLazyLoadRelationColumn() = default; + + bool IsNull(int array_id, int id) override { + AdvanceTo(array_id); + return delegated.IsNull(array_id, id); + } + + bool IsEqualTo(int x_array_id, int x_id, int y_array_id, int y_id) override { + AdvanceTo(x_array_id); + AdvanceTo(y_array_id); + return delegated.IsEqualTo(x_array_id, x_id, y_array_id, y_id); + } + + arrow::Status FromLazyBatchIterator(std::shared_ptr in, + int field_id) override { + in_ = in; + field_id_ = field_id; + return arrow::Status::OK(); + }; + + arrow::Status AdvanceTo(int array_id) override { + if (array_id < current_array_id_) { + return arrow::Status::OK(); + } + for (int i = current_array_id_; i <= array_id; i++) { + std::shared_ptr batch = in_->GetBatch(i); + std::shared_ptr array = batch->column(field_id_); + delegated.AppendColumn(array); + in_->RetainBatch(i); + array_released.push_back(false); + } + current_array_id_ = array_id + 1; + return arrow::Status::OK(); + } + + // return actual advanced array count + int32_t Advance(int32_t array_offset) override { + for (int i = 0; i < array_offset; i++) { + int target_batch = current_array_id_; + std::shared_ptr batch = in_->GetBatch(target_batch); + if (batch == nullptr) { + return i; + } + std::shared_ptr array = batch->column(field_id_); + delegated.AppendColumn(array); + in_->RetainBatch(target_batch); + array_released.push_back(false); + current_array_id_++; + } + return array_offset; + }; + + arrow::Status ReleaseArray(int array_id) override { + if (array_id >= array_released.size()) { + return arrow::Status::OK(); + } + if (array_released[array_id]) { + return arrow::Status::OK(); + } + RETURN_NOT_OK(delegated.ReleaseArray(array_id)); + in_->ReleaseBatch(array_id); + array_released[array_id] = true; + return arrow::Status::OK(); + } + + arrow::Status GetArrayVector(std::vector>* out) override { + return delegated.GetArrayVector(out); + } + + T GetValue(int array_id, int id) override { + AdvanceTo(array_id); + return delegated.GetValue(array_id, id); + } + bool HasNull() override { return has_null_; } + + private: + std::shared_ptr in_; + TypedRelationColumn delegated; + int current_array_id_ = 0; + int field_id_ = -1; + std::vector array_released; + bool has_null_ = true; // fixme always true +}; + +template +class TypedLazyLoadRelationColumn> + : public TypedRelationColumn { + public: + TypedLazyLoadRelationColumn() = default; + bool IsNull(int array_id, int id) override { + AdvanceTo(array_id); + return delegated.IsNull(array_id, id); + } + bool IsEqualTo(int x_array_id, int x_id, int y_array_id, int y_id) override { + AdvanceTo(x_array_id); + AdvanceTo(y_array_id); + return delegated.IsEqualTo(x_array_id, x_id, y_array_id, y_id); + } + + arrow::Status FromLazyBatchIterator(std::shared_ptr in, + int field_id) override { + in_ = in; + field_id_ = field_id; + return arrow::Status::OK(); + }; + + arrow::Status AdvanceTo(int array_id) override { + if (array_id < current_array_id_) { + return arrow::Status::OK(); + } + for (int i = current_array_id_; i <= array_id; i++) { + std::shared_ptr batch = in_->GetBatch(i); + std::shared_ptr array = batch->column(field_id_); + delegated.AppendColumn(array); + in_->RetainBatch(i); + array_released.push_back(false); + } + current_array_id_ = array_id + 1; + return arrow::Status::OK(); + } + + // return actual advanced array count + int32_t Advance(int32_t array_offset) override { + for (int i = 0; i < array_offset; i++) { + int target_batch = current_array_id_; + std::shared_ptr batch = in_->GetBatch(target_batch); + if (batch == nullptr) { + return i; + } + std::shared_ptr array = batch->column(field_id_); + delegated.AppendColumn(array); + in_->RetainBatch(target_batch); + array_released.push_back(false); + current_array_id_++; + } + return array_offset; + }; + + arrow::Status ReleaseArray(int array_id) override { + if (array_id >= array_released.size()) { + return arrow::Status::OK(); + } + if (array_released[array_id]) { + return arrow::Status::OK(); + } + RETURN_NOT_OK(delegated.ReleaseArray(array_id)); + in_->ReleaseBatch(array_id); + array_released[array_id] = true; + return arrow::Status::OK(); + } + + arrow::Status GetArrayVector(std::vector>* out) override { + return delegated.GetArrayVector(out); + } + + std::string GetValue(int array_id, int id) override { + AdvanceTo(array_id); + return delegated.GetValue(array_id, id); + } + + bool HasNull() override { return has_null_; } + + private: + std::shared_ptr in_; + TypedRelationColumn delegated; + int32_t current_array_id_ = 0; + int32_t field_id_ = -1; + std::vector array_released; + bool has_null_ = true; // fixme always true +}; + +arrow::Status MakeLazyLoadRelationColumn(uint32_t data_type_id, + std::shared_ptr* out); diff --git a/native-sql-engine/cpp/src/codegen/common/sort_relation.h b/native-sql-engine/cpp/src/codegen/common/sort_relation.h index f0caba42e..dc7e95355 100644 --- a/native-sql-engine/cpp/src/codegen/common/sort_relation.h +++ b/native-sql-engine/cpp/src/codegen/common/sort_relation.h @@ -36,57 +36,224 @@ using sparkcolumnarplugin::precompile::TypeTraits; class SortRelation { public: SortRelation( - arrow::compute::ExecContext* ctx, uint64_t items_total, - const std::vector& size_array, + arrow::compute::ExecContext* ctx, std::shared_ptr lazy_in, const std::vector>& sort_relation_key_list, - const std::vector>& sort_relation_payload_list) + const std::vector>& sort_relation_payload_list, + uint64_t items_total, const std::vector& size_array) : ctx_(ctx), items_total_(items_total) { + lazy_in_ = lazy_in; sort_relation_key_list_ = sort_relation_key_list; sort_relation_payload_list_ = sort_relation_payload_list; - int64_t buf_size = items_total_ * sizeof(ArrayItemIndexS); - auto maybe_buffer = arrow::AllocateBuffer(buf_size, ctx_->memory_pool()); - indices_buf_ = *std::move(maybe_buffer); - indices_begin_ = reinterpret_cast(indices_buf_->mutable_data()); - uint64_t idx = 0; - int array_id = 0; - for (auto size : size_array) { - for (int id = 0; id < size; id++) { - indices_begin_[idx].array_id = array_id; - indices_begin_[idx].id = id; - idx++; + + if (lazy_in_ != nullptr) { + is_lazy_input_ = true; + } + + if (!is_lazy_input_) { + int64_t buf_size = items_total_ * sizeof(ArrayItemIndexS); + auto maybe_buffer = arrow::AllocateBuffer(buf_size, ctx_->memory_pool()); + indices_buf_ = *std::move(maybe_buffer); + indices_begin_ = reinterpret_cast(indices_buf_->mutable_data()); + uint64_t idx = 0; + int array_id = 0; + for (auto size : size_array) { + for (int id = 0; id < size; id++) { + indices_begin_[idx].array_id = array_id; + indices_begin_[idx].id = id; + idx++; + } + array_id++; + } + } + } + + ~SortRelation() = default; + + static std::shared_ptr CreateLegacy( + arrow::compute::ExecContext* ctx, + const std::vector>& sort_relation_key_list, + const std::vector>& sort_relation_payload_list, + uint64_t items_total, const std::vector& size_array) { + return std::make_shared(ctx, nullptr, sort_relation_key_list, + sort_relation_payload_list, items_total, + size_array); + } + + static std::shared_ptr CreateLazy( + arrow::compute::ExecContext* ctx, std::shared_ptr lazy_in, + const std::vector>& sort_relation_key_list, + const std::vector>& sort_relation_payload_list) { + return std::make_shared(ctx, lazy_in, sort_relation_key_list, + sort_relation_payload_list, -1L, + std::vector()); + } + + void ArrayRelease(int array_id) { + for (auto col : sort_relation_key_list_) { + col->ReleaseArray(array_id); + } + for (auto col : sort_relation_payload_list_) { + col->ReleaseArray(array_id); + } + } + + int32_t ArrayAdvance(int32_t array_offset) { + int32_t result = -1; + for (auto col : sort_relation_payload_list_) { + int32_t granted = col->Advance(array_offset); + if (result == -1) { + result = granted; + continue; + } + if (granted != result) { + return -1; // error } - array_id++; } + return result; + } - std::shared_ptr out_type; + void ArrayAdvanceTo(int array_id) { + for (auto col : sort_relation_key_list_) { + col->Advance(array_id); + } + for (auto col : sort_relation_payload_list_) { + col->Advance(array_id); + } } - ~SortRelation() {} + void Advance(int shift) { + int64_t batch_length = lazy_in_->GetNumRowsOfBatch(requested_batches); + int64_t batch_remaining = (batch_length - 1) - offset_in_current_batch_; + if (shift <= batch_remaining) { + offset_in_current_batch_ = offset_in_current_batch_ + shift; + return; + } + int64_t remaining = shift - batch_remaining; + int32_t batch_i = requested_batches + 1; + while (true) { + int64_t current_batch_length = lazy_in_->GetNumRowsOfBatch(batch_i); + if (remaining <= current_batch_length) { + requested_batches = batch_i; + ArrayAdvanceTo(requested_batches); + for (int32_t i = 0; i < requested_batches; i++) { + ArrayRelease(i); + } + offset_in_current_batch_ = remaining - 1; + return; + } + remaining -= current_batch_length; + batch_i++; + } + } ArrayItemIndexS GetItemIndexWithShift(int shift) { - return indices_begin_[offset_ + shift]; + if (!is_lazy_input_) { + return indices_begin_[offset_ + shift]; + } + int64_t batch_length = lazy_in_->GetNumRowsOfBatch(requested_batches); + int64_t batch_remaining = (batch_length - 1) - offset_in_current_batch_; + if (shift <= batch_remaining) { + ArrayItemIndexS s(requested_batches, offset_in_current_batch_ + shift); + return s; + } + int64_t remaining = shift - batch_remaining; + int32_t batch_i = requested_batches + 1; + while (true) { + int64_t current_batch_length = lazy_in_->GetNumRowsOfBatch(batch_i); + if (remaining <= current_batch_length) { + ArrayItemIndexS s(batch_i, remaining - 1); + return s; + } + remaining -= current_batch_length; + batch_i++; + } + } + + bool CheckRangeBound(int shift) { + if (!is_lazy_input_) { + return offset_ + shift < items_total_; + } + int64_t batch_length = lazy_in_->GetNumRowsOfBatch(requested_batches); + if (batch_length == -1L) { + return false; + } + int64_t batch_remaining = (batch_length - 1) - offset_in_current_batch_; + if (shift <= batch_remaining) { + return true; + } + int64_t remaining = shift - batch_remaining; + int32_t batch_i = requested_batches + 1; + while (remaining >= 0) { + int64_t current_batch_length = lazy_in_->GetNumRowsOfBatch(batch_i); + if (current_batch_length == -1L) { + return false; + } + remaining -= current_batch_length; + batch_i++; + } + return true; } + // IS THIS POSSIBLY BUGGY AS THE FIRST ELEMENT DID NOT GET CHECKED? bool Next() { + if (!is_lazy_input_) { + if (!CheckRangeBound(1)) return false; + offset_++; + range_cache_ = -1; + return true; + } if (!CheckRangeBound(1)) return false; + Advance(1); offset_++; range_cache_ = -1; return true; } bool NextNewKey() { + if (!is_lazy_input_) { + auto range = GetSameKeyRange(); + if (!CheckRangeBound(range)) return false; + offset_ += range; + range_cache_ = -1; + return true; + } auto range = GetSameKeyRange(); if (!CheckRangeBound(range)) return false; + Advance(range); offset_ += range; range_cache_ = -1; - return true; } int GetSameKeyRange() { + if (!is_lazy_input_) { + if (range_cache_ != -1) return range_cache_; + int range = 0; + if (!CheckRangeBound(range)) return range; + bool is_same = true; + while (is_same) { + if (CheckRangeBound(range + 1)) { + auto cur_idx = GetItemIndexWithShift(range); + auto cur_idx_plus_one = GetItemIndexWithShift(range + 1); + for (auto col : sort_relation_key_list_) { + if (!(is_same = + col->IsEqualTo(cur_idx.array_id, cur_idx.id, + cur_idx_plus_one.array_id, cur_idx_plus_one.id))) + break; + } + } else { + is_same = false; + } + if (!is_same) break; + range++; + } + range += 1; + range_cache_ = range; + return range; + } if (range_cache_ != -1) return range_cache_; + if (!CheckRangeBound(0)) return 0; int range = 0; - if (!CheckRangeBound(range)) return range; bool is_same = true; while (is_same) { if (CheckRangeBound(range + 1)) { @@ -108,8 +275,6 @@ class SortRelation { return range; } - bool CheckRangeBound(int shift) { return offset_ + shift < items_total_; } - template arrow::Status GetColumn(int idx, std::shared_ptr* out) { *out = std::dynamic_pointer_cast(sort_relation_payload_list_[idx]); @@ -118,11 +283,18 @@ class SortRelation { protected: arrow::compute::ExecContext* ctx_; - std::shared_ptr indices_buf_; - ArrayItemIndexS* indices_begin_; - const uint64_t items_total_; + std::shared_ptr lazy_in_; uint64_t offset_ = 0; + int64_t offset_in_current_batch_ = 0; + int32_t requested_batches = 0; int range_cache_ = -1; std::vector> sort_relation_key_list_; std::vector> sort_relation_payload_list_; + + // required by legacy method + bool is_lazy_input_ = false; + + std::shared_ptr indices_buf_; + ArrayItemIndexS* indices_begin_; + const uint64_t items_total_; }; diff --git a/native-sql-engine/cpp/src/codegen/compute_ext/code_generator.h b/native-sql-engine/cpp/src/codegen/compute_ext/code_generator.h index cf18d4d32..20b3a6d55 100644 --- a/native-sql-engine/cpp/src/codegen/compute_ext/code_generator.h +++ b/native-sql-engine/cpp/src/codegen/compute_ext/code_generator.h @@ -51,6 +51,8 @@ class ComputeExtCodeGenerator : public CodeGenerator { return status; } + arrow::Status evaluate(arrow::RecordBatchIterator in) { return arrow::Status::OK(); } + arrow::Status finish(std::vector>* out) { return arrow::Status::OK(); } diff --git a/native-sql-engine/cpp/src/codegen/gandiva/code_generator.h b/native-sql-engine/cpp/src/codegen/gandiva/code_generator.h index 64b718e8f..e3f0170ac 100644 --- a/native-sql-engine/cpp/src/codegen/gandiva/code_generator.h +++ b/native-sql-engine/cpp/src/codegen/gandiva/code_generator.h @@ -46,6 +46,7 @@ class GandivaCodeGenerator : public CodeGenerator { std::vector>* out) { return arrow::Status::OK(); } + arrow::Status evaluate(arrow::RecordBatchIterator in) { return arrow::Status::OK(); } arrow::Status finish(std::vector>* out) { return arrow::Status::OK(); } diff --git a/native-sql-engine/cpp/src/jni/jni_wrapper.cc b/native-sql-engine/cpp/src/jni/jni_wrapper.cc index 7351e581d..f0582bfb2 100644 --- a/native-sql-engine/cpp/src/jni/jni_wrapper.cc +++ b/native-sql-engine/cpp/src/jni/jni_wrapper.cc @@ -26,12 +26,14 @@ #include #include #include +#include #include #include #include #include #include +#include #include "codegen/code_generator_factory.h" #include "codegen/common/hash_relation.h" @@ -61,12 +63,16 @@ static jmethodID serializable_obj_builder_constructor; static jclass split_result_class; static jmethodID split_result_constructor; +jclass serialized_record_batch_iterator_class; static jclass metrics_builder_class; static jmethodID metrics_builder_constructor; static jclass unsafe_row_class; static jmethodID unsafe_row_class_constructor; static jmethodID unsafe_row_class_point_to; +jmethodID serialized_record_batch_iterator_hasNext; +jmethodID serialized_record_batch_iterator_next; + using arrow::jni::ConcurrentMap; static ConcurrentMap> buffer_holder_; @@ -111,6 +117,40 @@ std::shared_ptr> GetBatchIterator(JNIEnv* env, jlong id) { return std::dynamic_pointer_cast>(handler); } +arrow::Result> FromBytes( + JNIEnv* env, std::shared_ptr schema, jbyteArray bytes) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr batch, + arrow::jniutil::DeserializeUnsafeFromJava(env, schema, bytes)) + return batch; +} + +// See Java class +// org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator +// +arrow::Result MakeJavaRecordBatchIterator( + JavaVM* vm, jobject java_serialized_record_batch_iterator, + std::shared_ptr schema) { + std::shared_ptr schema_moved = std::move(schema); + arrow::RecordBatchIterator itr = arrow::MakeFunctionIterator( + [vm, java_serialized_record_batch_iterator, + schema_moved]() -> arrow::Result> { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return arrow::Status::Invalid("JNIEnv was not attached to current thread"); + } + if (!env->CallBooleanMethod(java_serialized_record_batch_iterator, + serialized_record_batch_iterator_hasNext)) { + return nullptr; // stream ended + } + auto bytes = (jbyteArray)env->CallObjectMethod( + java_serialized_record_batch_iterator, serialized_record_batch_iterator_next); + RETURN_NOT_OK(arrow::jniutil::CheckException(env)); + ARROW_ASSIGN_OR_RAISE(auto batch, FromBytes(env, schema_moved, bytes)); + return batch; + }); + return itr; +} + jobject MakeRecordBatchBuilder(JNIEnv* env, std::shared_ptr schema, std::shared_ptr record_batch) { jobjectArray field_array = @@ -212,6 +252,15 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { env, "Lorg/apache/spark/sql/catalyst/expressions/UnsafeRow;"); unsafe_row_class_constructor = GetMethodID(env, unsafe_row_class, "", "(I)V"); unsafe_row_class_point_to = GetMethodID(env, unsafe_row_class, "pointTo", "([BI)V"); + serialized_record_batch_iterator_class = + CreateGlobalClassReference(env, + "Lorg/apache/arrow/" + "dataset/jni/NativeSerializedRecordBatchIterator;"); + serialized_record_batch_iterator_hasNext = + GetMethodID(env, serialized_record_batch_iterator_class, "hasNext", "()Z"); + serialized_record_batch_iterator_next = + GetMethodID(env, serialized_record_batch_iterator_class, "next", "()[B"); + return JNI_VERSION; } @@ -231,6 +280,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) { env->DeleteGlobalRef(serializable_obj_builder_class); env->DeleteGlobalRef(split_result_class); env->DeleteGlobalRef(unsafe_row_class); + env->DeleteGlobalRef(serialized_record_batch_iterator_class); buffer_holder_.Clear(); handler_holder_.Clear(); @@ -489,6 +539,40 @@ Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeEvaluate( return record_batch_builder_array; } +JNIEXPORT void JNICALL +Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeEvaluateWithIterator( + JNIEnv* env, jobject obj, jlong id, jobject itr) { + JavaVM* vm; + if (env->GetJavaVM(&vm) != JNI_OK) { + std::string error_message = "Unable to get JavaVM instance"; + env->ThrowNew(io_exception_class, error_message.c_str()); + } + arrow::Status status; + std::shared_ptr handler = GetCodeGenerator(env, id); + std::shared_ptr schema; + status = handler->getSchema(&schema); + + // IMPORTANT: DO NOT USE LOCAL REF IN DIFFERENT THREAD + // TODO Release this in JNI Unload or dependent object's destructor + jobject itr2 = env->NewGlobalRef(itr); + arrow::Result rb_itr_status = + MakeJavaRecordBatchIterator(vm, itr2, schema); + + if (!rb_itr_status.ok()) { + std::string error_message = + "nativeEvaluate: error making java iterator" + rb_itr_status.status().ToString(); + env->ThrowNew(io_exception_class, error_message.c_str()); + } + + status = handler->evaluate(std::move(rb_itr_status.ValueOrDie())); + + if (!status.ok()) { + std::string error_message = + "nativeEvaluate: evaluate failed with error msg " + status.ToString(); + env->ThrowNew(io_exception_class, error_message.c_str()); + } +} + JNIEXPORT jstring JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeGetSignature( JNIEnv* env, jobject obj, jlong id) { diff --git a/native-sql-engine/cpp/src/tests/arrow_compute_test_wscg.cc b/native-sql-engine/cpp/src/tests/arrow_compute_test_wscg.cc index 04b8eccaf..d62959637 100644 --- a/native-sql-engine/cpp/src/tests/arrow_compute_test_wscg.cc +++ b/native-sql-engine/cpp/src/tests/arrow_compute_test_wscg.cc @@ -4151,10 +4151,24 @@ TEST(TestArrowComputeWSCG, WSCGTestContinuousMergeJoinSemiExistenceWithCondition ASSERT_NOT_OK(expr->finish(&build_result_iterator)); auto rb_iter = std::dynamic_pointer_cast>( build_result_iterator); - while (rb_iter->HasNext()) { - std::shared_ptr result_batch; - ASSERT_NOT_OK(rb_iter->Next(&result_batch)); - ASSERT_NOT_OK(cache->evaluate(result_batch, &dummy_result_batches)); + if (false) { + // lazy read + arrow::RecordBatchIterator itr = arrow::MakeFunctionIterator( + [rb_iter]() -> arrow::Result> { + if (!rb_iter->HasNext()) { + return nullptr; + } + std::shared_ptr batch; + ASSERT_NOT_OK(rb_iter->Next(&batch)); + return batch; + }); + ASSERT_NOT_OK(cache->evaluate(std::move(itr))); + } else { + while (rb_iter->HasNext()) { + std::shared_ptr result_batch; + ASSERT_NOT_OK(rb_iter->Next(&result_batch)); + ASSERT_NOT_OK(cache->evaluate(result_batch, &dummy_result_batches)); + } } ASSERT_NOT_OK(cache->finish(&build_result_iterator)); dependency_iterator_list.push_back(build_result_iterator); @@ -4998,10 +5012,24 @@ TEST(TestArrowComputeWSCG, WSCGTestStringMergeInnerJoinWithGroupbyAggregate) { ASSERT_NOT_OK(expr->finish(&build_result_iterator)); auto rb_iter = std::dynamic_pointer_cast>( build_result_iterator); - while (rb_iter->HasNext()) { - std::shared_ptr result_batch; - ASSERT_NOT_OK(rb_iter->Next(&result_batch)); - ASSERT_NOT_OK(cache->evaluate(result_batch, &dummy_result_batches)); + if (false) { + // lazy read + arrow::RecordBatchIterator itr = arrow::MakeFunctionIterator( + [rb_iter]() -> arrow::Result> { + if (!rb_iter->HasNext()) { + return nullptr; + } + std::shared_ptr batch; + ASSERT_NOT_OK(rb_iter->Next(&batch)); + return batch; + }); + ASSERT_NOT_OK(cache->evaluate(std::move(itr))); + } else { + while (rb_iter->HasNext()) { + std::shared_ptr result_batch; + ASSERT_NOT_OK(rb_iter->Next(&result_batch)); + ASSERT_NOT_OK(cache->evaluate(result_batch, &dummy_result_batches)); + } } ASSERT_NOT_OK(cache->finish(&build_result_iterator)); dependency_iterator_list.push_back(build_result_iterator); From fb31888e8bc65b0937cb545a7d8f237a34685ba3 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Fri, 23 Jul 2021 16:58:50 +0800 Subject: [PATCH 2/3] rebase --- .../cpp/src/codegen/arrow_compute/ext/kernels_ext.cc | 4 ++-- native-sql-engine/cpp/src/jni/jni_wrapper.cc | 7 ------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc index 5e3eb53a5..b8c11383a 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc @@ -339,7 +339,7 @@ class CachedRelationKernel::Impl { col_num_ = result_schema->num_fields(); } - arrow::Status Evaluate(ArrayList& in) { + arrow::Status Evaluate(const ArrayList& in) { items_total_ += in[0]->length(); length_list_.push_back(in[0]->length()); if (cached_.size() < col_num_) { @@ -453,7 +453,7 @@ CachedRelationKernel::CachedRelationKernel( kernel_name_ = "CachedRelationKernel"; } -arrow::Status CachedRelationKernel::Evaluate(ArrayList& in) { +arrow::Status CachedRelationKernel::Evaluate(const ArrayList& in) { return impl_->Evaluate(in); } diff --git a/native-sql-engine/cpp/src/jni/jni_wrapper.cc b/native-sql-engine/cpp/src/jni/jni_wrapper.cc index f0582bfb2..1c4524348 100644 --- a/native-sql-engine/cpp/src/jni/jni_wrapper.cc +++ b/native-sql-engine/cpp/src/jni/jni_wrapper.cc @@ -1094,13 +1094,6 @@ Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_nativeSpill( return spilled_size; } -arrow::Result> FromBytes( - JNIEnv* env, std::shared_ptr schema, jbyteArray bytes) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr batch, - arrow::jniutil::DeserializeUnsafeFromJava(env, schema, bytes)) - return batch; -} - JNIEXPORT jobject JNICALL Java_com_intel_oap_vectorized_ExpressionEvaluatorJniWrapper_nativeEvaluate2( JNIEnv* env, jobject obj, jlong id, jbyteArray bytes) { From 9cb0ab8870a4f11bf1b8c939c4c4a609880dd111 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Fri, 23 Jul 2021 17:24:17 +0800 Subject: [PATCH 3/3] rebase --- .../cpp/src/codegen/arrow_compute/expr_visitor.h | 1 - .../cpp/src/codegen/arrow_compute/ext/kernels_ext.cc | 4 ++-- .../cpp/src/codegen/arrow_compute/ext/kernels_ext.h | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.h b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.h index 9a48fe1a9..aa01eca52 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.h @@ -155,7 +155,6 @@ class ExprVisitor : public std::enable_shared_from_this { arrow::Status Eval(const std::shared_ptr& selection_in, const std::shared_ptr& in); arrow::Status Eval(std::shared_ptr& in); - arrow::Status Eval(const std::shared_ptr& in); arrow::Status Eval(arrow::RecordBatchIterator in); arrow::Status Eval(); std::string GetSignature() { return signature_; } diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc index b8c11383a..5e3eb53a5 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.cc @@ -339,7 +339,7 @@ class CachedRelationKernel::Impl { col_num_ = result_schema->num_fields(); } - arrow::Status Evaluate(const ArrayList& in) { + arrow::Status Evaluate(ArrayList& in) { items_total_ += in[0]->length(); length_list_.push_back(in[0]->length()); if (cached_.size() < col_num_) { @@ -453,7 +453,7 @@ CachedRelationKernel::CachedRelationKernel( kernel_name_ = "CachedRelationKernel"; } -arrow::Status CachedRelationKernel::Evaluate(const ArrayList& in) { +arrow::Status CachedRelationKernel::Evaluate(ArrayList& in) { return impl_->Evaluate(in); } diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h index 90931e253..d96842aeb 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h @@ -237,7 +237,7 @@ class CachedRelationKernel : public KernalBase { std::shared_ptr result_schema, std::vector> key_field_list, int result_type); - arrow::Status Evaluate(const ArrayList& in) override; + arrow::Status Evaluate(ArrayList& in) override; arrow::Status Evaluate(arrow::RecordBatchIterator in) override; arrow::Status MakeResultIterator( std::shared_ptr schema,