Skip to content

Commit

Permalink
apacheGH-43759: [C++] Acero: Minor code enhancement for Join (apache#…
Browse files Browse the repository at this point in the history
…43760)

### Rationale for this change

Minor style enhancement for join

### What changes are included in this PR?

Minor style enhancement for join

### Are these changes tested?

Covered by existing

### Are there any user-facing changes?

no

* GitHub Issue: apache#43759

Authored-by: mwish <[email protected]>
Signed-off-by: mwish <[email protected]>
  • Loading branch information
mapleFU authored Aug 29, 2024
1 parent 45592f9 commit 4f91c8f
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 59 deletions.
9 changes: 4 additions & 5 deletions cpp/src/arrow/acero/hash_join_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,20 @@ Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr<Array> dictiona
return Status::OK();
}

dictionary_ = dictionary;
dictionary_ = std::move(dictionary);

// Initialize encoder
RowEncoder encoder;
std::vector<TypeHolder> encoder_types;
encoder_types.emplace_back(value_type_);
std::vector<TypeHolder> encoder_types{value_type_};
encoder.Init(encoder_types, ctx);

// Encode all dictionary values
int64_t length = dictionary->data()->length;
int64_t length = dictionary_->data()->length;
if (length >= std::numeric_limits<int32_t>::max()) {
return Status::Invalid(
"Dictionary length in hash join must fit into signed 32-bit integer.");
}
RETURN_NOT_OK(encoder.EncodeAndAppend(ExecSpan({*dictionary->data()}, length)));
RETURN_NOT_OK(encoder.EncodeAndAppend(ExecSpan({*dictionary_->data()}, length)));

std::vector<int32_t> entries_to_take;

Expand Down
16 changes: 8 additions & 8 deletions cpp/src/arrow/acero/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,30 +61,30 @@ Result<std::vector<FieldRef>> HashJoinSchema::ComputePayload(
const std::vector<FieldRef>& filter, const std::vector<FieldRef>& keys) {
// payload = (output + filter) - keys, with no duplicates
std::unordered_set<int> payload_fields;
for (auto ref : output) {
for (const auto& ref : output) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.insert(match[0]);
}

for (auto ref : filter) {
for (const auto& ref : filter) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.insert(match[0]);
}

for (auto ref : keys) {
for (const auto& ref : keys) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.erase(match[0]);
}

std::vector<FieldRef> payload_refs;
for (auto ref : output) {
for (const auto& ref : output) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
if (payload_fields.find(match[0]) != payload_fields.end()) {
payload_refs.push_back(ref);
payload_fields.erase(match[0]);
}
}
for (auto ref : filter) {
for (const auto& ref : filter) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
if (payload_fields.find(match[0]) != payload_fields.end()) {
payload_refs.push_back(ref);
Expand Down Expand Up @@ -198,7 +198,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc
return Status::Invalid("Different number of key fields on left (", left_keys.size(),
") and right (", right_keys.size(), ") side of the join");
}
if (left_keys.size() < 1) {
if (left_keys.empty()) {
return Status::Invalid("Join key cannot be empty");
}
for (size_t i = 0; i < left_keys.size() + right_keys.size(); ++i) {
Expand Down Expand Up @@ -432,7 +432,7 @@ Status HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
indices[0] -= left_schema.num_fields();
FieldPath corrected_path(std::move(indices));
if (right_seen_paths.find(*path) == right_seen_paths.end()) {
right_filter.push_back(corrected_path);
right_filter.emplace_back(corrected_path);
right_seen_paths.emplace(std::move(corrected_path));
}
} else if (left_seen_paths.find(*path) == left_seen_paths.end()) {
Expand Down Expand Up @@ -698,7 +698,7 @@ class HashJoinNode : public ExecNode, public TracedNode {
std::shared_ptr<Schema> output_schema,
std::unique_ptr<HashJoinSchema> schema_mgr, Expression filter,
std::unique_ptr<HashJoinImpl> impl)
: ExecNode(plan, inputs, {"left", "right"},
: ExecNode(plan, std::move(inputs), {"left", "right"},
/*output_schema=*/std::move(output_schema)),
TracedNode(this),
join_type_(join_options.join_type),
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/acero/hash_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class ARROW_ACERO_EXPORT HashJoinSchema {
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
const std::string& right_field_name_suffix);

bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }
bool LeftPayloadIsEmpty() const { return PayloadIsEmpty(0); }

bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); }
bool RightPayloadIsEmpty() const { return PayloadIsEmpty(1); }

static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
Expand All @@ -88,7 +88,7 @@ class ARROW_ACERO_EXPORT HashJoinSchema {
const SchemaProjectionMap& right_to_filter,
const Expression& filter);

bool PayloadIsEmpty(int side) {
bool PayloadIsEmpty(int side) const {
assert(side == 0 || side == 1);
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0;
}
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ Result<std::shared_ptr<ArrayData>> JoinResultMaterialize::FlushBuildColumn(
const std::shared_ptr<DataType>& data_type, const RowArray* row_array, int column_id,
uint32_t* row_ids) {
ResizableArrayData output;
output.Init(data_type, pool_, bit_util::Log2(num_rows_));
RETURN_NOT_OK(output.Init(data_type, pool_, bit_util::Log2(num_rows_)));

for (size_t i = 0; i <= null_ranges_.size(); ++i) {
int row_id_begin =
Expand Down Expand Up @@ -2247,8 +2247,9 @@ Result<ExecBatch> JoinResidualFilter::MaterializeFilterInput(
build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
for (int i = 0; i < num_build_cols; ++i) {
ResizableArrayData column_data;
column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER, i), pool_,
bit_util::Log2(num_batch_rows));
RETURN_NOT_OK(
column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER, i),
pool_, bit_util::Log2(num_batch_rows)));
if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField) {
RETURN_NOT_OK(build_keys_->DecodeSelected(&column_data, idx, num_batch_rows,
key_ids_maybe_null, pool_));
Expand Down
68 changes: 32 additions & 36 deletions cpp/src/arrow/compute/light_array_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,9 @@ Result<KeyColumnMetadata> ColumnMetadataFromDataType(
const std::shared_ptr<DataType>& type) {
const bool is_extension = type->id() == Type::EXTENSION;
const std::shared_ptr<DataType>& typ =
is_extension
? arrow::internal::checked_pointer_cast<ExtensionType>(type->GetSharedPtr())
->storage_type()
: type;
is_extension ? arrow::internal::checked_cast<const ExtensionType*>(type.get())
->storage_type()
: type;

if (typ->id() == Type::DICTIONARY) {
auto bit_width =
Expand Down Expand Up @@ -205,22 +204,25 @@ Status ColumnArraysFromExecBatch(const ExecBatch& batch,
column_arrays);
}

void ResizableArrayData::Init(const std::shared_ptr<DataType>& data_type,
MemoryPool* pool, int log_num_rows_min) {
Status ResizableArrayData::Init(const std::shared_ptr<DataType>& data_type,
MemoryPool* pool, int log_num_rows_min) {
#ifndef NDEBUG
if (num_rows_allocated_ > 0) {
ARROW_DCHECK(data_type_ != NULLPTR);
KeyColumnMetadata metadata_before =
ColumnMetadataFromDataType(data_type_).ValueOrDie();
KeyColumnMetadata metadata_after = ColumnMetadataFromDataType(data_type).ValueOrDie();
ARROW_DCHECK(data_type_ != nullptr);
const KeyColumnMetadata& metadata_before = column_metadata_;
ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata metadata_after,
ColumnMetadataFromDataType(data_type));
ARROW_DCHECK(metadata_before.is_fixed_length == metadata_after.is_fixed_length &&
metadata_before.fixed_length == metadata_after.fixed_length);
}
#endif
ARROW_DCHECK(data_type != nullptr);
ARROW_ASSIGN_OR_RAISE(column_metadata_, ColumnMetadataFromDataType(data_type));
Clear(/*release_buffers=*/false);
log_num_rows_min_ = log_num_rows_min;
data_type_ = data_type;
pool_ = pool;
return Status::OK();
}

void ResizableArrayData::Clear(bool release_buffers) {
Expand All @@ -246,8 +248,6 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
num_rows_allocated_new *= 2;
}

KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();

if (buffers_[kFixedLengthBuffer] == NULLPTR) {
ARROW_DCHECK(buffers_[kValidityBuffer] == NULLPTR &&
buffers_[kVariableLengthBuffer] == NULLPTR);
Expand All @@ -258,8 +258,8 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes, pool_));
memset(mutable_data(kValidityBuffer), 0,
bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes);
if (column_metadata.is_fixed_length) {
if (column_metadata.fixed_length == 0) {
if (column_metadata_.is_fixed_length) {
if (column_metadata_.fixed_length == 0) {
ARROW_ASSIGN_OR_RAISE(
buffers_[kFixedLengthBuffer],
AllocateResizableBuffer(
Expand All @@ -271,7 +271,7 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
ARROW_ASSIGN_OR_RAISE(
buffers_[kFixedLengthBuffer],
AllocateResizableBuffer(
num_rows_allocated_new * column_metadata.fixed_length + kNumPaddingBytes,
num_rows_allocated_new * column_metadata_.fixed_length + kNumPaddingBytes,
pool_));
}
} else {
Expand Down Expand Up @@ -300,15 +300,15 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
memset(mutable_data(kValidityBuffer) + bytes_for_bits_before, 0,
bytes_for_bits_after - bytes_for_bits_before);

if (column_metadata.is_fixed_length) {
if (column_metadata.fixed_length == 0) {
if (column_metadata_.is_fixed_length) {
if (column_metadata_.fixed_length == 0) {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes));
memset(mutable_data(kFixedLengthBuffer) + bytes_for_bits_before, 0,
bytes_for_bits_after - bytes_for_bits_before);
} else {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
num_rows_allocated_new * column_metadata.fixed_length + kNumPaddingBytes));
num_rows_allocated_new * column_metadata_.fixed_length + kNumPaddingBytes));
}
} else {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
Expand All @@ -323,10 +323,7 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int num_rows_new) {
}

Status ResizableArrayData::ResizeVaryingLengthBuffer() {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();

if (!column_metadata.is_fixed_length) {
if (!column_metadata_.is_fixed_length) {
int64_t min_new_size = buffers_[kFixedLengthBuffer]->data_as<int32_t>()[num_rows_];
ARROW_DCHECK(var_len_buf_size_ > 0);
if (var_len_buf_size_ < min_new_size) {
Expand All @@ -343,23 +340,19 @@ Status ResizableArrayData::ResizeVaryingLengthBuffer() {
}

KeyColumnArray ResizableArrayData::column_array() const {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
return KeyColumnArray(column_metadata, num_rows_,
return KeyColumnArray(column_metadata_, num_rows_,
buffers_[kValidityBuffer]->mutable_data(),
buffers_[kFixedLengthBuffer]->mutable_data(),
buffers_[kVariableLengthBuffer]->mutable_data());
}

std::shared_ptr<ArrayData> ResizableArrayData::array_data() const {
KeyColumnMetadata column_metadata;
column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();

auto valid_count = arrow::internal::CountSetBits(
buffers_[kValidityBuffer]->data(), /*offset=*/0, static_cast<int64_t>(num_rows_));
auto valid_count =
arrow::internal::CountSetBits(buffers_[kValidityBuffer]->data(), /*bit_offset=*/0,
static_cast<int64_t>(num_rows_));
int null_count = static_cast<int>(num_rows_) - static_cast<int>(valid_count);

if (column_metadata.is_fixed_length) {
if (column_metadata_.is_fixed_length) {
return ArrayData::Make(data_type_, num_rows_,
{buffers_[kValidityBuffer], buffers_[kFixedLengthBuffer]},
null_count);
Expand Down Expand Up @@ -493,10 +486,12 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr<ArrayData>& source
ARROW_DCHECK(num_rows_before >= 0);
int num_rows_after = num_rows_before + num_rows_to_append;
if (target->num_rows() == 0) {
target->Init(source->type, pool, kLogNumRows);
RETURN_NOT_OK(target->Init(source->type, pool, kLogNumRows));
}
RETURN_NOT_OK(target->ResizeFixedLengthBuffers(num_rows_after));

// Since target->Init is called before, we can assume that the ColumnMetadata
// would never fail to be created
KeyColumnMetadata column_metadata =
ColumnMetadataFromDataType(source->type).ValueOrDie();

Expand Down Expand Up @@ -647,11 +642,12 @@ Status ExecBatchBuilder::AppendNulls(const std::shared_ptr<DataType>& type,
int num_rows_before = target.num_rows();
int num_rows_after = num_rows_before + num_rows_to_append;
if (target.num_rows() == 0) {
target.Init(type, pool, kLogNumRows);
RETURN_NOT_OK(target.Init(type, pool, kLogNumRows));
}
RETURN_NOT_OK(target.ResizeFixedLengthBuffers(num_rows_after));

KeyColumnMetadata column_metadata = ColumnMetadataFromDataType(type).ValueOrDie();
ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata column_metadata,
ColumnMetadataFromDataType(type));

// Process fixed length buffer
//
Expand Down Expand Up @@ -708,7 +704,7 @@ Status ExecBatchBuilder::AppendSelected(MemoryPool* pool, const ExecBatch& batch
const Datum& data = batch.values[col_ids ? col_ids[i] : i];
ARROW_DCHECK(data.is_array());
const std::shared_ptr<ArrayData>& array_data = data.array();
values_[i].Init(array_data->type, pool, kLogNumRows);
RETURN_NOT_OK(values_[i].Init(array_data->type, pool, kLogNumRows));
}
}

Expand Down Expand Up @@ -739,7 +735,7 @@ Status ExecBatchBuilder::AppendNulls(MemoryPool* pool,
if (values_.empty()) {
values_.resize(types.size());
for (size_t i = 0; i < types.size(); ++i) {
values_[i].Init(types[i], pool, kLogNumRows);
RETURN_NOT_OK(values_[i].Init(types[i], pool, kLogNumRows));
}
}

Expand Down
6 changes: 4 additions & 2 deletions cpp/src/arrow/compute/light_array_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ class ARROW_EXPORT ResizableArrayData {
/// \param pool The pool to make allocations on
/// \param log_num_rows_min All resize operations will allocate at least enough
/// space for (1 << log_num_rows_min) rows
void Init(const std::shared_ptr<DataType>& data_type, MemoryPool* pool,
int log_num_rows_min);
Status Init(const std::shared_ptr<DataType>& data_type, MemoryPool* pool,
int log_num_rows_min);

/// \brief Resets the array back to an empty state
/// \param release_buffers If true then allocated memory is released and the
Expand Down Expand Up @@ -351,6 +351,8 @@ class ARROW_EXPORT ResizableArrayData {
static constexpr int64_t kNumPaddingBytes = 64;
int log_num_rows_min_;
std::shared_ptr<DataType> data_type_;
// Would be valid if data_type_ != NULLPTR.
KeyColumnMetadata column_metadata_{};
MemoryPool* pool_;
int num_rows_;
int num_rows_allocated_;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/light_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ TEST(ResizableArrayData, Basic) {
arrow::internal::checked_pointer_cast<FixedWidthType>(type)->bit_width() / 8;
{
ResizableArrayData array;
array.Init(type, pool.get(), /*log_num_rows_min=*/16);
ASSERT_OK(array.Init(type, pool.get(), /*log_num_rows_min=*/16));
ASSERT_EQ(0, array.num_rows());
ASSERT_OK(array.ResizeFixedLengthBuffers(2));
ASSERT_EQ(2, array.num_rows());
Expand Down Expand Up @@ -330,7 +330,7 @@ TEST(ResizableArrayData, Binary) {
ARROW_SCOPED_TRACE("Type: ", type->ToString());
{
ResizableArrayData array;
array.Init(type, pool.get(), /*log_num_rows_min=*/4);
ASSERT_OK(array.Init(type, pool.get(), /*log_num_rows_min=*/4));
ASSERT_EQ(0, array.num_rows());
ASSERT_OK(array.ResizeFixedLengthBuffers(2));
ASSERT_EQ(2, array.num_rows());
Expand Down

0 comments on commit 4f91c8f

Please sign in to comment.