Respond to Weston comments
save-buffer committed Jan 12, 2023
1 parent 12f3b5b commit 5cb8c50
Showing 12 changed files with 198 additions and 191 deletions.
93 changes: 48 additions & 45 deletions cpp/src/arrow/compute/exec/accumulation_queue.cc
@@ -61,12 +61,16 @@ Status SpillingAccumulationQueue::Init(QueryContext* ctx) {
ctx_ = ctx;
partition_locks_.Init(ctx_->max_concurrency(), kNumPartitions);
for (size_t ipart = 0; ipart < kNumPartitions; ipart++) {
task_group_read_[ipart] = ctx_->RegisterTaskGroup(
Partition& part = partitions_[ipart];
part.task_group_read = ctx_->RegisterTaskGroup(
[this, ipart](size_t thread_index, int64_t batch_index) {
return read_back_fn_[ipart](thread_index, static_cast<size_t>(batch_index),
std::move(queues_[ipart][batch_index]));
return partitions_[ipart].read_back_fn(
thread_index, static_cast<size_t>(batch_index),
std::move(partitions_[ipart].queue[batch_index]));
},
[this, ipart](size_t thread_index) { return on_finished_[ipart](thread_index); });
[this, ipart](size_t thread_index) {
return partitions_[ipart].on_finished(thread_index);
});
}
return Status::OK();
}
@@ -89,37 +93,39 @@ Status SpillingAccumulationQueue::InsertBatch(size_t thread_index, ExecBatch bat
int unprocessed_partition_ids[kNumPartitions];
RETURN_NOT_OK(partition_locks_.ForEachPartition(
thread_index, unprocessed_partition_ids,
/*is_prtn_empty=*/
/*is_prtn_empty_fn=*/
[&](int part_id) { return part_starts[part_id + 1] == part_starts[part_id]; },
/*partition=*/
/*process_prtn_fn=*/
[&](int locked_part_id_int) {
size_t locked_part_id = static_cast<size_t>(locked_part_id_int);
uint64_t num_total_rows_to_append =
part_starts[locked_part_id + 1] - part_starts[locked_part_id];

Partition& locked_part = partitions_[locked_part_id];

size_t offset = static_cast<size_t>(part_starts[locked_part_id]);
while (num_total_rows_to_append > 0) {
int num_rows_to_append =
std::min(static_cast<int>(num_total_rows_to_append),
static_cast<int>(ExecBatchBuilder::num_rows_max() -
builders_[locked_part_id].num_rows()));
locked_part.builder.num_rows()));

RETURN_NOT_OK(builders_[locked_part_id].AppendSelected(
RETURN_NOT_OK(locked_part.builder.AppendSelected(
ctx_->memory_pool(), batch, num_rows_to_append, permutation.data() + offset,
batch.num_values()));

if (builders_[locked_part_id].is_full()) {
ExecBatch batch = builders_[locked_part_id].Flush();
if (locked_part.builder.is_full()) {
ExecBatch batch = locked_part.builder.Flush();
Datum hash = std::move(batch.values.back());
batch.values.pop_back();
ExecBatch hash_batch({std::move(hash)}, batch.length);
if (locked_part_id < spilling_cursor_)
RETURN_NOT_OK(files_[locked_part_id].SpillBatch(ctx_, std::move(batch)));
RETURN_NOT_OK(locked_part.file.SpillBatch(ctx_, std::move(batch)));
else
queues_[locked_part_id].InsertBatch(std::move(batch));
locked_part.queue.InsertBatch(std::move(batch));

if (locked_part_id >= hash_cursor_)
hash_queues_[locked_part_id].InsertBatch(std::move(hash_batch));
locked_part.hash_queue.InsertBatch(std::move(hash_batch));
}
offset += num_rows_to_append;
num_total_rows_to_append -= num_rows_to_append;
@@ -129,56 +135,52 @@ Status SpillingAccumulationQueue::InsertBatch(size_t thread_index, ExecBatch bat
return Status::OK();
}

const uint64_t* SpillingAccumulationQueue::GetHashes(size_t partition, size_t batch_idx) {
ARROW_DCHECK(partition >= hash_cursor_.load());
if (batch_idx > hash_queues_[partition].batch_count()) {
const Datum& datum = hash_queues_[partition][batch_idx].values[0];
const uint64_t* SpillingAccumulationQueue::GetHashes(size_t partition_idx,
size_t batch_idx) {
ARROW_DCHECK(partition_idx >= hash_cursor_.load());
Partition& partition = partitions_[partition_idx];
if (batch_idx > partition.hash_queue.batch_count()) {
const Datum& datum = partition.hash_queue[batch_idx].values[0];
return reinterpret_cast<const uint64_t*>(datum.array()->buffers[1]->data());
} else {
size_t hash_idx = builders_[partition].num_cols();
KeyColumnArray kca = builders_[partition].column(hash_idx - 1);
size_t hash_idx = partition.builder.num_cols();
KeyColumnArray kca = partition.builder.column(hash_idx - 1);
return reinterpret_cast<const uint64_t*>(kca.data(1));
}
}

Status SpillingAccumulationQueue::GetPartition(
size_t thread_index, size_t partition,
size_t thread_index, size_t partition_idx,
std::function<Status(size_t, size_t, ExecBatch)> on_batch,
std::function<Status(size_t)> on_finished) {
bool is_in_memory = partition >= spilling_cursor_.load();
if (builders_[partition].num_rows() > 0) {
ExecBatch batch = builders_[partition].Flush();
Datum hash = std::move(batch.values.back());
bool is_in_memory = partition_idx >= spilling_cursor_.load();
Partition& partition = partitions_[partition_idx];
if (partition.builder.num_rows() > 0) {
ExecBatch batch = partition.builder.Flush();
batch.values.pop_back();
if (is_in_memory) {
ExecBatch hash_batch({std::move(hash)}, batch.length);
hash_queues_[partition].InsertBatch(std::move(hash_batch));
queues_[partition].InsertBatch(std::move(batch));
} else {
RETURN_NOT_OK(on_batch(thread_index,
/*batch_index=*/queues_[partition].batch_count(),
std::move(batch)));
}
RETURN_NOT_OK(on_batch(thread_index,
/*batch_index=*/partition.queue.batch_count(),
std::move(batch)));
}

if (is_in_memory) {
ARROW_DCHECK(partition >= hash_cursor_.load());
read_back_fn_[partition] = std::move(on_batch);
on_finished_[partition] = std::move(on_finished);
return ctx_->StartTaskGroup(task_group_read_[partition],
queues_[partition].batch_count());
ARROW_DCHECK(partition_idx >= hash_cursor_.load());
partition.read_back_fn = std::move(on_batch);
partition.on_finished = std::move(on_finished);
return ctx_->StartTaskGroup(partition.task_group_read, partition.queue.batch_count());
}

return files_[partition].ReadBackBatches(
return partition.file.ReadBackBatches(
ctx_, on_batch,
[this, partition, finished = std::move(on_finished)](size_t thread_index) {
RETURN_NOT_OK(files_[partition].Cleanup());
[this, partition_idx, finished = std::move(on_finished)](size_t thread_index) {
RETURN_NOT_OK(partitions_[partition_idx].file.Cleanup());
return finished(thread_index);
});
}

size_t SpillingAccumulationQueue::CalculatePartitionRowCount(size_t partition) const {
return builders_[partition].num_rows() + queues_[partition].CalculateRowCount();
return partitions_[partition].builder.num_rows() +
partitions_[partition].queue.CalculateRowCount();
}

Result<bool> SpillingAccumulationQueue::AdvanceSpillCursor() {
@@ -191,9 +193,10 @@ Result<bool> SpillingAccumulationQueue::AdvanceSpillCursor() {
}

auto lock = partition_locks_.AcquirePartitionLock(static_cast<int>(to_spill));
size_t num_batches = queues_[to_spill].batch_count();
Partition& partition = partitions_[to_spill];
size_t num_batches = partition.queue.batch_count();
for (size_t i = 0; i < num_batches; i++)
RETURN_NOT_OK(files_[to_spill].SpillBatch(ctx_, std::move(queues_[to_spill][i])));
RETURN_NOT_OK(partition.file.SpillBatch(ctx_, std::move(partition.queue[i])));
return true;
}

@@ -207,7 +210,7 @@ Result<bool> SpillingAccumulationQueue::AdvanceHashCursor() {
}

auto lock = partition_locks_.AcquirePartitionLock(static_cast<int>(to_spill));
hash_queues_[to_spill].Clear();
partitions_[to_spill].hash_queue.Clear();
return true;
}
} // namespace compute
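
The bulk of the change above replaces the per-partition parallel arrays (queues_, hash_queues_, builders_, files_, task_group_read_, and the two callback arrays) with a single partitions_ array whose elements own all of one partition's state. A minimal sketch of that consolidation pattern, using placeholder member types rather than the real Arrow classes:

#include <array>
#include <cstddef>

// Placeholder components standing in for AccumulationQueue, ExecBatchBuilder,
// SpillFile, etc.
struct Queue {};
struct Builder {};
struct File {};

constexpr std::size_t kNumPartitions = 64;  // illustrative value

// Before: one parallel array per piece of per-partition state; every access
// repeats the partition index.
struct QueueStateBefore {
  std::array<Queue, kNumPartitions> queues;
  std::array<Queue, kNumPartitions> hash_queues;
  std::array<Builder, kNumPartitions> builders;
  std::array<File, kNumPartitions> files;
};

// After: one array of a Partition struct, so callers can bind a single
// Partition& and reach all of that partition's members through one reference.
struct QueueStateAfter {
  struct Partition {
    Queue queue;
    Queue hash_queue;
    Builder builder;
    File file;
  };
  std::array<Partition, kNumPartitions> partitions;
};
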
88 changes: 64 additions & 24 deletions cpp/src/arrow/compute/exec/accumulation_queue.h
@@ -63,6 +63,27 @@ class AccumulationQueue {
std::vector<ExecBatch> batches_;
};

/// Accumulates batches in a queue that can be spilled to disk if needed
///
/// Each batch is partitioned by the lower bits of the hash column (which must be present),
/// and rows are initially accumulated in batch builders (one per partition). As a batch
/// builder fills up, the completed batch is put into an in-memory accumulation queue (per
/// partition).
///
/// When memory pressure is encountered, the spilling queue's "spill cursor" can be
/// advanced. This will cause a partition to be spilled to disk. Any future data
/// arriving for that partition will go immediately to disk (after accumulating a full
/// batch in the batch builder). Note that hashes are spilled separately from batches and
/// have their own cursor. We assume that the batch spill cursor is advanced faster than
/// the hash cursor. Hashes are spilled separately to enable building a Bloom filter for
/// spilled partitions.
///
/// Later, data is retrieved one partition at a time. Partitions that are in-memory will
/// be delivered immediately in new thread tasks. Partitions that are on disk will be
/// read from disk and delivered as they arrive.
///
/// This class assumes that data is fully accumulated before it is read back. As such, do
/// not call InsertBatch after calling GetPartition.
class SpillingAccumulationQueue {
public:
// Number of partitions must be a power of two, since we assign partitions by
@@ -72,39 +93,57 @@ class SpillingAccumulationQueue {
Status Init(QueryContext* ctx);
// Assumes that the final column in batch contains 64-bit hashes of the columns.
Status InsertBatch(size_t thread_index, ExecBatch batch);
Status GetPartition(size_t thread_index, size_t partition,
// Runs `on_batch` on each batch in the SpillingAccumulationQueue for the given
// partition. Each batch will have its own task. Once all batches have had their
// on_batch function run, `on_finished` will be called.
Status GetPartition(size_t thread_index, size_t partition_idx,
std::function<Status(size_t, size_t, ExecBatch)>
on_batch, // thread_index, batch_index, batch
std::function<Status(size_t)> on_finished);

// Returns hashes of the given partition and batch index.
// partition_idx MUST be at least hash_cursor, since if partition_idx < hash_cursor,
// these hashes will have been deleted.
const uint64_t* GetHashes(size_t partition, size_t batch_idx);
inline size_t batch_count(size_t partition) const {
size_t num_full_batches = partition >= spilling_cursor_
? queues_[partition].batch_count()
: files_[partition].num_batches();

return num_full_batches + (builders_[partition].num_rows() > 0);
const uint64_t* GetHashes(size_t partition_idx, size_t batch_idx);
inline size_t batch_count(size_t partition_idx) const {
const Partition& partition = partitions_[partition_idx];
size_t num_full_batches = partition_idx >= spilling_cursor_
? partition.queue.batch_count()
: partition.file.num_batches();

return num_full_batches + (partition.builder.num_rows() > 0);
}
inline size_t row_count(size_t partition, size_t batch_idx) const {
if (batch_idx < hash_queues_[partition].batch_count())
return hash_queues_[partition][batch_idx].length;

inline size_t row_count(size_t partition_idx, size_t batch_idx) const {
const Partition& partition = partitions_[partition_idx];
if (batch_idx < partition.hash_queue.batch_count())
return partition.hash_queue[batch_idx].length;
else
return builders_[partition].num_rows();
return partition.builder.num_rows();
}

static inline constexpr size_t partition_id(uint64_t hash) {
// Hash Table uses the top bits of the hash, so we really really
// need to use the bottom bits of the hash for spilling to avoid
// Hash Table uses the top bits of the hash, so it is important
// to use the bottom bits of the hash for spilling to avoid
// a huge number of hash collisions per partition.
return static_cast<size_t>(hash & (kNumPartitions - 1));
}

// Returns the row count for the partition if it is still in-memory.
// Returns 0 if the partition has already been spilled.
size_t CalculatePartitionRowCount(size_t partition) const;

// Spills the next partition of batches to disk and returns true,
// or returns false if too many partitions have been spilled.
// The QueryContext's bytes_in_flight will be increased by the
// number of bytes spilled (unless the disk IO was very fast and
// the bytes_in_flight got reduced again).
//
// We expect that we always advance the SpillCursor faster than the
// HashCursor, and only advance the HashCursor when we've exhausted
// partitions for the SpillCursor.
Result<bool> AdvanceSpillCursor();
// Same as AdvanceSpillCursor but spills the hashes for the partition.
Result<bool> AdvanceHashCursor();
inline size_t spill_cursor() const { return spilling_cursor_.load(); }
inline size_t hash_cursor() const { return hash_cursor_.load(); }
@@ -116,16 +155,17 @@ class SpillingAccumulationQueue {
QueryContext* ctx_;
PartitionLocks partition_locks_;

AccumulationQueue queues_[kNumPartitions];
AccumulationQueue hash_queues_[kNumPartitions];

ExecBatchBuilder builders_[kNumPartitions];

SpillFile files_[kNumPartitions];

int task_group_read_[kNumPartitions];
std::function<Status(size_t, size_t, ExecBatch)> read_back_fn_[kNumPartitions];
std::function<Status(size_t)> on_finished_[kNumPartitions];
struct Partition {
AccumulationQueue queue;
AccumulationQueue hash_queue;
ExecBatchBuilder builder;
SpillFile file;
int task_group_read;
std::function<Status(size_t, size_t, ExecBatch)> read_back_fn;
std::function<Status(size_t)> on_finished;
};

Partition partitions_[kNumPartitions];
};

} // namespace compute
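
To make the class comment above concrete, here is a minimal usage sketch of the interface declared in this header. It is an editor's illustration, not code from the commit: the thread index, the batch source, and the include paths are assumptions, and a real caller would typically request the next partition from inside the previous partition's on_finished callback rather than directly after GetPartition returns.

#include <utility>
#include <vector>

#include "arrow/compute/exec/accumulation_queue.h"  // assumed header path
#include "arrow/compute/exec/query_context.h"       // assumed header path

namespace cp = arrow::compute;

// Sketch: accumulate batches (whose last column holds 64-bit hashes, as
// InsertBatch requires), optionally spill, then read one partition back.
arrow::Status AccumulateAndReadBack(cp::QueryContext* ctx,
                                    cp::SpillingAccumulationQueue* queue,
                                    std::vector<cp::ExecBatch> batches) {
  RETURN_NOT_OK(queue->Init(ctx));

  // Rows are routed to partitions by the low bits of the trailing hash column.
  for (cp::ExecBatch& batch : batches)
    RETURN_NOT_OK(queue->InsertBatch(/*thread_index=*/0, std::move(batch)));

  // Under memory pressure: spill one more partition's batches to disk.
  // Returns false once every partition has already been spilled.
  ARROW_ASSIGN_OR_RAISE(bool spilled, queue->AdvanceSpillCursor());
  (void)spilled;

  // Read a partition back. Each batch gets its own task via on_batch;
  // on_finished runs once all of the partition's batches were delivered.
  return queue->GetPartition(
      /*thread_index=*/0, /*partition_idx=*/0,
      /*on_batch=*/
      [](size_t thread_index, size_t batch_index, cp::ExecBatch batch) {
        // Consume the batch here, e.g. feed a hash table build.
        return arrow::Status::OK();
      },
      /*on_finished=*/[](size_t thread_index) { return arrow::Status::OK(); });
}
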
8 changes: 7 additions & 1 deletion cpp/src/arrow/compute/exec/exec_plan.cc
@@ -620,10 +620,16 @@ Result<std::vector<std::shared_ptr<RecordBatch>>> DeclarationToBatches(

Future<BatchesWithCommonSchema> DeclarationToExecBatchesAsync(Declaration declaration,
ExecContext exec_context) {
return DeclarationToExecBatchesAsync(std::move(declaration), exec_context,
QueryOptions{});
}

Future<BatchesWithCommonSchema> DeclarationToExecBatchesAsync(
Declaration declaration, ExecContext exec_context, QueryOptions query_options) {
std::shared_ptr<Schema> out_schema;
AsyncGenerator<std::optional<ExecBatch>> sink_gen;
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> exec_plan,
ExecPlan::Make(exec_context));
ExecPlan::Make(query_options, exec_context));
Declaration with_sink = Declaration::Sequence(
{declaration, {"sink", SinkNodeOptions(&sink_gen, &out_schema)}});
ARROW_RETURN_NOT_OK(with_sink.AddToPlan(exec_plan.get()));
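
The new overload simply forwards the caller's QueryOptions to ExecPlan::Make. A hedged usage sketch follows; the include paths, the use of the default exec context, and the 1 GiB figure are illustrative assumptions, while max_memory_bytes is the field whose default changes in query_context.cc further down.

#include <cstdint>
#include <utility>

#include "arrow/compute/exec/exec_plan.h"      // declares DeclarationToExecBatchesAsync
#include "arrow/compute/exec/query_context.h"  // assumed location of QueryOptions

// Sketch: run a declaration with an explicit memory budget instead of the
// default budget (now half of total system memory).
arrow::Future<arrow::compute::BatchesWithCommonSchema> RunWithMemoryBudget(
    arrow::compute::Declaration decl) {
  arrow::compute::QueryOptions query_options;
  query_options.max_memory_bytes = int64_t{1} << 30;  // e.g. a 1 GiB budget
  return arrow::compute::DeclarationToExecBatchesAsync(
      std::move(decl), *arrow::compute::default_exec_context(), query_options);
}
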
7 changes: 7 additions & 0 deletions cpp/src/arrow/compute/exec/exec_plan.h
@@ -477,6 +477,13 @@ ARROW_EXPORT Future<BatchesWithCommonSchema> DeclarationToExecBatchesAsync(
ARROW_EXPORT Future<BatchesWithCommonSchema> DeclarationToExecBatchesAsync(
Declaration declaration, ExecContext custom_exec_context);

/// \brief Overload of \see DeclarationToExecBatchesAsync accepting a custom exec context
/// and QueryOptions
///
/// \see DeclarationToTableAsync for details on threading & execution
ARROW_EXPORT Future<BatchesWithCommonSchema> DeclarationToExecBatchesAsync(
Declaration declaration, ExecContext custom_exec_context, QueryOptions query_options);

/// \brief Utility method to run a declaration and collect the results into a vector
///
/// \see DeclarationToTable for details on threading & execution
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/hash_join_benchmark.cc
@@ -211,13 +211,13 @@ static void HashJoinBasicBenchmarkImpl(benchmark::State& st,
BenchmarkSettings& settings) {
uint64_t total_rows = 0;
for (auto _ : st) {
st.PauseTiming();
{
JoinBenchmark bm(settings);
st.ResumeTiming();
bm.RunJoin();
st.PauseTiming();
total_rows += bm.stats_.num_probe_rows;
st.PauseTiming();
}
st.ResumeTiming();
}
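
The one-line move in this benchmark only changes which statements sit inside the timed region. For reference, a standalone Google Benchmark sketch of the PauseTiming/ResumeTiming pattern it relies on; the workload here is a placeholder, not the Arrow hash join.

#include <cstdint>
#include <vector>

#include <benchmark/benchmark.h>

// Everything between PauseTiming() and the next ResumeTiming() (setup,
// bookkeeping, teardown) is excluded from the measured time.
static void BM_TimedRegionOnly(benchmark::State& state) {
  uint64_t total_rows = 0;
  for (auto _ : state) {
    state.PauseTiming();                 // exclude per-iteration setup
    std::vector<int> input(1 << 16, 1);
    state.ResumeTiming();

    uint64_t sum = 0;                    // the region we actually want to time
    for (int v : input) sum += v;
    benchmark::DoNotOptimize(sum);

    state.PauseTiming();                 // exclude stats bookkeeping
    total_rows += input.size();
    state.ResumeTiming();
  }
  state.counters["rows"] = static_cast<double>(total_rows);
}
BENCHMARK(BM_TimedRegionOnly);
BENCHMARK_MAIN();
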
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/query_context.cc
@@ -23,7 +23,7 @@ namespace arrow {
using internal::CpuInfo;
namespace compute {
QueryOptions::QueryOptions()
: max_memory_bytes(::arrow::internal::GetTotalMemoryBytes()),
: max_memory_bytes(::arrow::internal::GetTotalMemoryBytes() / 2),
use_legacy_batching(false) {}

QueryContext::QueryContext(QueryOptions opts, ExecContext exec_context)