From 1502609085956ffd80aee5244d34faaafecd2313 Mon Sep 17 00:00:00 2001 From: Francis <455954986@qq.com> Date: Sat, 14 Oct 2023 01:57:49 +0800 Subject: [PATCH] GH-37895: [C++] Feature: support concatenate recordbatches. (#37896) ### Rationale for this change User scenario: When we use acero plan, many smaller batches may be generated through agg and hashjoin. In addition, due to the mpp database, there is data distribution. When there are many segments, each segment data is compared at this time. Small, in order to improve performance, we hope to merge multiple fragmented small batches into one large batch for calculation together. ### What changes are included in this PR? record_batch.cc record_batch.h record_batch_test.cc ### Are these changes tested? yes, see record_batch_test.cc ### Are there any user-facing changes? yes * Closes: #37895 Authored-by: light-city <455954986@qq.com> Signed-off-by: Benjamin Kietzman --- cpp/src/arrow/record_batch.cc | 33 ++++++++++++++++++++++++++ cpp/src/arrow/record_batch.h | 14 +++++++++++ cpp/src/arrow/record_batch_test.cc | 37 ++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+) diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index f0ee295c6347d..457135fa400d5 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -25,6 +25,7 @@ #include #include "arrow/array.h" +#include "arrow/array/concatenate.h" #include "arrow/array/validate.h" #include "arrow/pretty_print.h" #include "arrow/status.h" @@ -432,4 +433,36 @@ RecordBatchReader::~RecordBatchReader() { ARROW_WARN_NOT_OK(this->Close(), "Implicitly called RecordBatchReader::Close failed"); } +Result> ConcatenateRecordBatches( + const RecordBatchVector& batches, MemoryPool* pool) { + int64_t length = 0; + size_t n = batches.size(); + if (n == 0) { + return Status::Invalid("Must pass at least one recordbatch"); + } + int cols = batches[0]->num_columns(); + auto schema = batches[0]->schema(); + for (size_t i = 0; i < batches.size(); ++i) { + length += batches[i]->num_rows(); + if (!schema->Equals(batches[i]->schema())) { + return Status::Invalid( + "Schema of RecordBatch index ", i, " is ", batches[i]->schema()->ToString(), + ", which does not match index 0 recordbatch schema: ", schema->ToString()); + } + } + + std::vector> concatenated_columns; + concatenated_columns.reserve(cols); + for (int col = 0; col < cols; ++col) { + ArrayVector column_arrays; + column_arrays.reserve(batches.size()); + for (const auto& batch : batches) { + column_arrays.emplace_back(batch->column(col)); + } + ARROW_ASSIGN_OR_RAISE(auto concatenated_column, Concatenate(column_arrays, pool)) + concatenated_columns.emplace_back(std::move(concatenated_column)); + } + return RecordBatch::Make(std::move(schema), length, std::move(concatenated_columns)); +} + } // namespace arrow diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index cb1f6d54f7cff..1a66fc3fb5629 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -350,4 +350,18 @@ class ARROW_EXPORT RecordBatchReader { Iterator> batches, std::shared_ptr schema); }; +/// \brief Concatenate record batches +/// +/// The columns of the new batch are formed by concatenate the same columns of each input +/// batch. Concatenate multiple batches into a new batch requires that the schema must be +/// consistent. It supports merging batches without columns (only length, scenarios such +/// as count(*)). +/// +/// \param[in] batches a vector of record batches to be concatenated +/// \param[in] pool memory to store the result will be allocated from this memory pool +/// \return the concatenated record batch +ARROW_EXPORT +Result> ConcatenateRecordBatches( + const RecordBatchVector& batches, MemoryPool* pool = default_memory_pool()); + } // namespace arrow diff --git a/cpp/src/arrow/record_batch_test.cc b/cpp/src/arrow/record_batch_test.cc index bc923a1444160..db3a2d3def73f 100644 --- a/cpp/src/arrow/record_batch_test.cc +++ b/cpp/src/arrow/record_batch_test.cc @@ -555,4 +555,41 @@ TEST_F(TestRecordBatch, ReplaceSchema) { ASSERT_RAISES(Invalid, b1->ReplaceSchema(schema)); } +TEST_F(TestRecordBatch, ConcatenateRecordBatches) { + int length = 10; + + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8()); + + auto schema = ::arrow::schema({f0, f1}); + + random::RandomArrayGenerator gen(42); + + auto b1 = gen.BatchOf(schema->fields(), length); + + length = 5; + + auto b2 = gen.BatchOf(schema->fields(), length); + + ASSERT_OK_AND_ASSIGN(auto batch, ConcatenateRecordBatches({b1, b2})); + ASSERT_EQ(batch->num_rows(), b1->num_rows() + b2->num_rows()); + ASSERT_BATCHES_EQUAL(*batch->Slice(0, b1->num_rows()), *b1); + ASSERT_BATCHES_EQUAL(*batch->Slice(b1->num_rows()), *b2); + + f0 = field("fd0", int32()); + f1 = field("fd1", uint8()); + + schema = ::arrow::schema({f0, f1}); + + auto b3 = gen.BatchOf(schema->fields(), length); + + ASSERT_RAISES(Invalid, ConcatenateRecordBatches({b1, b3})); + + auto null_batch = RecordBatch::Make(::arrow::schema({}), length, + std::vector>{}); + ASSERT_OK_AND_ASSIGN(batch, ConcatenateRecordBatches({null_batch})); + ASSERT_EQ(batch->num_rows(), null_batch->num_rows()); + ASSERT_BATCHES_EQUAL(*batch, *null_batch); +} + } // namespace arrow