Skip to content

Commit

Permalink
try to make ci happy
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Sep 2, 2024
1 parent f5f0f33 commit 7ed5df4
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 36 deletions.
23 changes: 15 additions & 8 deletions cpp/src/arrow/acero/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,29 @@ class HashJoinBasicImpl : public HashJoinImpl {
std::string ToString() const override { return "HashJoinBasicImpl"; }

private:
- void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
+ Status InitEncoder(int side, HashJoinProjection projection_handle,
+ RowEncoder* encoder) {
std::vector<TypeHolder> data_types;
int num_cols = schema_[side]->num_cols(projection_handle);
data_types.resize(num_cols);
for (int icol = 0; icol < num_cols; ++icol) {
data_types[icol] = schema_[side]->data_type(projection_handle, icol);
}
- encoder->Init(data_types, ctx_->exec_context());
+ RETURN_NOT_OK(encoder->Init(data_types, ctx_->exec_context()));
encoder->Clear();
+ return Status::OK();
}

Status InitLocalStateIfNeeded(size_t thread_index) {
DCHECK_LT(thread_index, local_states_.size());
ThreadLocalState& local_state = local_states_[thread_index];
if (!local_state.is_initialized) {
- InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
+ RETURN_NOT_OK(
+ InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys));
bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
- InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
+ RETURN_NOT_OK(InitEncoder(0, HashJoinProjection::PAYLOAD,
+ &local_state.exec_batch_payloads));
}
local_state.is_initialized = true;
}
Expand Down Expand Up @@ -512,8 +516,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_state.match_left.clear();
local_state.match_right.clear();

- bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
- thread_index, *schema_[0], *schema_[1], ctx_->exec_context());
+ ARROW_ASSIGN_OR_RAISE(
+ bool use_key_batch_for_dicts,
+ dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1],
+ ctx_->exec_context()));
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
if (use_key_batch_for_dicts) {
RETURN_NOT_OK(dict_probe_.EncodeBatch(
Expand Down Expand Up @@ -563,10 +569,11 @@ class HashJoinBasicImpl : public HashJoinImpl {

Status BuildHashTable_exec_task(size_t thread_index, int64_t /*task_id*/) {
AccumulationQueue batches = std::move(build_batches_);
- dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_->exec_context());
+ RETURN_NOT_OK(
+ dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_->exec_context()));
bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
- InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
+ RETURN_NOT_OK(InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_));
}
hash_table_empty_ = true;

Expand Down
36 changes: 20 additions & 16 deletions cpp/src/arrow/acero/hash_join_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr<Array> dictiona
// Initialize encoder
RowEncoder encoder;
std::vector<TypeHolder> encoder_types{value_type_};
- encoder.Init(encoder_types, ctx);
+ RETURN_NOT_OK(encoder.Init(encoder_types, ctx));

// Encode all dictionary values
int64_t length = dictionary_->data()->length;
Expand Down Expand Up @@ -290,7 +290,7 @@ Result<std::shared_ptr<ArrayData>> HashJoinDictBuild::RemapInputValues(
//
RowEncoder encoder;
std::vector<TypeHolder> encoder_types = {value_type_};
- encoder.Init(encoder_types, ctx);
+ RETURN_NOT_OK(encoder.Init(encoder_types, ctx));

// Encode all
//
Expand Down Expand Up @@ -426,7 +426,7 @@ Result<std::shared_ptr<ArrayData>> HashJoinDictProbe::RemapInput(
opt_build_side->RemapInputValues(ctx, Datum(dict->data()), dict->length()));
} else {
std::vector<TypeHolder> encoder_types = {dict_type.value_type()};
- encoder_.Init(encoder_types, ctx);
+ RETURN_NOT_OK(encoder_.Init(encoder_types, ctx));
RETURN_NOT_OK(
encoder_.EncodeAndAppend(ExecSpan({*dict->data()}, dict->length())));
}
Expand Down Expand Up @@ -514,7 +514,7 @@ Status HashJoinDictBuildMulti::Init(
return Status::OK();
}

- void HashJoinDictBuildMulti::InitEncoder(
+ Status HashJoinDictBuildMulti::InitEncoder(
const SchemaProjectionMaps<HashJoinProjection>& proj_map, RowEncoder* encoder,
ExecContext* ctx) {
int num_cols = proj_map.num_cols(HashJoinProjection::KEY);
Expand All @@ -525,9 +525,9 @@ void HashJoinDictBuildMulti::InitEncoder(
if (HashJoinDictBuild::KeyNeedsProcessing(data_type)) {
data_type = HashJoinDictBuild::DataTypeAfterRemapping();
}
- data_types[icol] = data_type;
+ data_types[icol] = std::move(data_type);
}
- encoder->Init(data_types, ctx);
+ return encoder->Init(data_types, ctx);
}

Status HashJoinDictBuildMulti::EncodeBatch(
Expand Down Expand Up @@ -568,20 +568,21 @@ Status HashJoinDictBuildMulti::PostDecode(

void HashJoinDictProbeMulti::Init(size_t num_threads) {
local_states_.resize(num_threads);
- for (size_t i = 0; i < local_states_.size(); ++i) {
- local_states_[i].is_initialized = false;
+ for (auto& local_state : local_states_) {
+ local_state.is_initialized = false;
}
}

- bool HashJoinDictProbeMulti::BatchRemapNeeded(
+ Result<bool> HashJoinDictProbeMulti::BatchRemapNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
- InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx);
+ RETURN_NOT_OK(
+ InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx));
DCHECK_LT(thread_index, local_states_.size());
return local_states_[thread_index].any_needs_remap;
}

- void HashJoinDictProbeMulti::InitLocalStateIfNeeded(
+ Status HashJoinDictProbeMulti::InitLocalStateIfNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
ThreadLocalState& local_state = local_states_[thread_index];
Expand All @@ -603,11 +604,13 @@ void HashJoinDictProbeMulti::InitLocalStateIfNeeded(
}

if (local_state.any_needs_remap) {
- InitEncoder(proj_map_probe, proj_map_build, &local_state.post_remap_encoder, ctx);
+ RETURN_NOT_OK(InitEncoder(proj_map_probe, proj_map_build,
+ &local_state.post_remap_encoder, ctx));
}
+ return Status::OK();
}

- void HashJoinDictProbeMulti::InitEncoder(
+ Status HashJoinDictProbeMulti::InitEncoder(
const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, RowEncoder* encoder,
ExecContext* ctx) {
Expand All @@ -616,14 +619,14 @@ void HashJoinDictProbeMulti::InitEncoder(
for (int icol = 0; icol < num_cols; ++icol) {
std::shared_ptr<DataType> data_type =
proj_map_probe.data_type(HashJoinProjection::KEY, icol);
- std::shared_ptr<DataType> build_data_type =
+ const std::shared_ptr<DataType>& build_data_type =
proj_map_build.data_type(HashJoinProjection::KEY, icol);
if (HashJoinDictProbe::KeyNeedsProcessing(data_type, build_data_type)) {
data_type = HashJoinDictProbe::DataTypeAfterRemapping(build_data_type);
}
data_types[icol] = data_type;
}
- encoder->Init(data_types, ctx);
+ return encoder->Init(data_types, ctx);
}

Status HashJoinDictProbeMulti::EncodeBatch(
Expand All @@ -632,7 +635,8 @@ Status HashJoinDictProbeMulti::EncodeBatch(
const HashJoinDictBuildMulti& dict_build, const ExecBatch& batch,
RowEncoder** out_encoder, ExecBatch* opt_out_key_batch, ExecContext* ctx) {
ThreadLocalState& local_state = local_states_[thread_index];
- InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx);
+ RETURN_NOT_OK(
+ InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx));

ExecBatch projected({}, batch.length);
int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY);
Expand Down
20 changes: 10 additions & 10 deletions cpp/src/arrow/acero/hash_join_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ class HashJoinDictBuildMulti {
public:
Status Init(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
const ExecBatch* opt_non_empty_batch, ExecContext* ctx);
- static void InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
- RowEncoder* encoder, ExecContext* ctx);
+ static Status InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map,
+ RowEncoder* encoder, ExecContext* ctx);
Status EncodeBatch(size_t thread_index,
const SchemaProjectionMaps<HashJoinProjection>& proj_map,
const ExecBatch& batch, RowEncoder* encoder, ExecContext* ctx) const;
Expand All @@ -280,10 +280,9 @@ class HashJoinDictBuildMulti {
class HashJoinDictProbeMulti {
public:
void Init(size_t num_threads);
- bool BatchRemapNeeded(size_t thread_index,
- const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
- const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
- ExecContext* ctx);
+ Result<bool> BatchRemapNeeded(
+ size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx);
Status EncodeBatch(size_t thread_index,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
Expand All @@ -292,12 +291,13 @@ class HashJoinDictProbeMulti {
ExecContext* ctx);

private:
- void InitLocalStateIfNeeded(
+ Status InitLocalStateIfNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx);
- static void InitEncoder(const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
- const SchemaProjectionMaps<HashJoinProjection>& proj_map_build,
- RowEncoder* encoder, ExecContext* ctx);
+ static Status InitEncoder(
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
+ const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, RowEncoder* encoder,
+ ExecContext* ctx);
struct ThreadLocalState {
bool is_initialized;
// Whether any key column needs remapping (because of dictionaries used) before doing
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/acero/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ std::vector<std::shared_ptr<Array>> GenRandomUniqueRecords(
val_types.push_back(result[i]->type());
}
RowEncoder encoder;
- encoder.Init(val_types, ctx);
+ auto s = encoder.Init(val_types, ctx);
ExecBatch batch({}, num_desired);
batch.values.resize(result.size());
for (size_t i = 0; i < result.size(); ++i) {
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/row/row_encoder_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ Result<std::shared_ptr<KeyEncoder>> MakeKeyEncoder(const TypeHolder& column_type
}

if (is_list(type.id())) {
- auto element_type = ::arrow::checked_cast<BaseListType*>(type.type)->value_type();
+ auto element_type =
+ ::arrow::checked_cast<const BaseListType*>(type.type)->value_type();
if (is_nested(element_type->id())) {
return Status::NotImplemented("Unsupported nested type in List for row encoder", type.ToString());
}
Expand Down

0 comments on commit 7ed5df4

Please sign in to comment.