Skip to content

Commit

Permalink
apacheGH-37895: [C++] Feature: support concatenate recordbatches. (ap…
Browse files Browse the repository at this point in the history
…ache#37896)

### Rationale for this change

User scenario: when running an Acero plan, many small batches may be produced by aggregation and hash-join operators. In addition, in an MPP database the data is distributed across segments, so when there are many segments each segment's batches are correspondingly small. To improve performance, we want to merge many fragmented small batches into one large batch so they can be processed together.

### What changes are included in this PR?

record_batch.cc
record_batch.h
record_batch_test.cc

### Are these changes tested?

yes, see record_batch_test.cc

### Are there any user-facing changes?

yes

* Closes: apache#37895

Authored-by: light-city <[email protected]>
Signed-off-by: Benjamin Kietzman <[email protected]>
  • Loading branch information
Light-City authored and loicalleyne committed Nov 13, 2023
1 parent 18cd817 commit 1502609
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
33 changes: 33 additions & 0 deletions cpp/src/arrow/record_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <utility>

#include "arrow/array.h"
#include "arrow/array/concatenate.h"
#include "arrow/array/validate.h"
#include "arrow/pretty_print.h"
#include "arrow/status.h"
Expand Down Expand Up @@ -432,4 +433,36 @@ RecordBatchReader::~RecordBatchReader() {
ARROW_WARN_NOT_OK(this->Close(), "Implicitly called RecordBatchReader::Close failed");
}

Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
    const RecordBatchVector& batches, MemoryPool* pool) {
  if (batches.empty()) {
    return Status::Invalid("Must pass at least one recordbatch");
  }

  // The first batch fixes the schema; every other batch must match it exactly.
  int cols = batches[0]->num_columns();
  auto schema = batches[0]->schema();
  int64_t length = 0;
  for (size_t i = 0; i < batches.size(); ++i) {
    length += batches[i]->num_rows();
    if (!schema->Equals(batches[i]->schema())) {
      return Status::Invalid(
          "Schema of RecordBatch index ", i, " is ", batches[i]->schema()->ToString(),
          ", which does not match index 0 recordbatch schema: ", schema->ToString());
    }
  }

  // Build the result column-by-column: output column `col` is the
  // concatenation of column `col` from every input batch.
  std::vector<std::shared_ptr<Array>> concatenated_columns;
  concatenated_columns.reserve(cols);
  for (int col = 0; col < cols; ++col) {
    ArrayVector column_arrays;
    column_arrays.reserve(batches.size());
    for (const auto& batch : batches) {
      column_arrays.emplace_back(batch->column(col));
    }
    // NOTE: the statement must be terminated with ';' — ARROW_ASSIGN_OR_RAISE
    // expands to an assignment expression.
    ARROW_ASSIGN_OR_RAISE(auto concatenated_column, Concatenate(column_arrays, pool));
    concatenated_columns.emplace_back(std::move(concatenated_column));
  }
  // Passing the accumulated `length` explicitly supports zero-column batches
  // (length only, e.g. the result of a count(*) aggregation).
  return RecordBatch::Make(std::move(schema), length, std::move(concatenated_columns));
}

} // namespace arrow
14 changes: 14 additions & 0 deletions cpp/src/arrow/record_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,4 +350,18 @@ class ARROW_EXPORT RecordBatchReader {
Iterator<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema> schema);
};

/// \brief Concatenate record batches
///
/// Each column of the output batch is the concatenation of the corresponding
/// column from every input batch. All input batches must have the same schema
/// (the schema of the batch at index 0); an empty input vector is an error.
/// Batches without columns (length only, e.g. the result of a count(*)
/// aggregation) are supported.
///
/// \param[in] batches a vector of record batches to be concatenated
/// \param[in] pool memory to store the result will be allocated from this memory pool
/// \return the concatenated record batch
ARROW_EXPORT
Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
    const RecordBatchVector& batches, MemoryPool* pool = default_memory_pool());

} // namespace arrow
37 changes: 37 additions & 0 deletions cpp/src/arrow/record_batch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,4 +555,41 @@ TEST_F(TestRecordBatch, ReplaceSchema) {
ASSERT_RAISES(Invalid, b1->ReplaceSchema(schema));
}

TEST_F(TestRecordBatch, ConcatenateRecordBatches) {
  auto f0 = field("f0", int32());
  auto f1 = field("f1", uint8());
  auto schema = ::arrow::schema({f0, f1});

  random::RandomArrayGenerator gen(42);

  // Two batches of different lengths over the same schema.
  const int kFirstLength = 10;
  const int kSecondLength = 5;
  auto batch_a = gen.BatchOf(schema->fields(), kFirstLength);
  auto batch_b = gen.BatchOf(schema->fields(), kSecondLength);

  // Concatenation succeeds and the result is the two inputs back-to-back.
  ASSERT_OK_AND_ASSIGN(auto combined, ConcatenateRecordBatches({batch_a, batch_b}));
  ASSERT_EQ(combined->num_rows(), batch_a->num_rows() + batch_b->num_rows());
  ASSERT_BATCHES_EQUAL(*combined->Slice(0, batch_a->num_rows()), *batch_a);
  ASSERT_BATCHES_EQUAL(*combined->Slice(batch_a->num_rows()), *batch_b);

  // A batch whose schema differs (renamed fields) must be rejected.
  auto renamed_schema = ::arrow::schema({field("fd0", int32()), field("fd1", uint8())});
  auto batch_c = gen.BatchOf(renamed_schema->fields(), kSecondLength);
  ASSERT_RAISES(Invalid, ConcatenateRecordBatches({batch_a, batch_c}));

  // Zero-column batches (length only) can also be concatenated.
  auto columnless_batch = RecordBatch::Make(::arrow::schema({}), kSecondLength,
                                            std::vector<std::shared_ptr<ArrayData>>{});
  ASSERT_OK_AND_ASSIGN(combined, ConcatenateRecordBatches({columnless_batch}));
  ASSERT_EQ(combined->num_rows(), columnless_batch->num_rows());
  ASSERT_BATCHES_EQUAL(*combined, *columnless_batch);
}

} // namespace arrow

0 comments on commit 1502609

Please sign in to comment.