Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-610] hashagg opt#1 #715

Merged
merged 6 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -763,28 +763,29 @@ class HashAggregateKernel::Impl {

std::vector<int> indices;
indices.resize(length, -1);

for (int i = 0; i < length; i++) {
auto aggr_key = typed_key_in->GetView(i);
auto aggr_key_validity =
typed_key_in->null_count() == 0 ? true : !typed_key_in->IsNull(i);

// 3. get key from hash_table
int memo_index = 0;
if (!aggr_key_validity) {
memo_index = aggr_hash_table_->GetOrInsertNull([](int) {}, [](int) {});
} else {
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &memo_index);
if (typed_key_in->null_count() > 0) {
for (int i = 0; i < length; i++) {
auto aggr_key = typed_key_in->GetView(i);
auto aggr_key_validity = !typed_key_in->IsNull(i);

if (aggr_key_validity) {
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
} else {
indices[i] = aggr_hash_table_->GetOrInsertNull([](int) {}, [](int) {});
}
}

if (memo_index > max_group_id_) {
max_group_id_ = memo_index;
} else {
for (int i = 0; i < length; i++) {
auto aggr_key = typed_key_in->GetView(i);
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
}
indices[i] = memo_index;
}

max_group_id_ = aggr_hash_table_->size_ - 1;
total_out_length_ = max_group_id_ + 1;

// 4. prepare action func and evaluate
std::vector<std::function<arrow::Status(int)>> eval_func_list;
std::vector<std::function<arrow::Status()>> eval_null_func_list;
Expand All @@ -802,15 +803,9 @@ class HashAggregateKernel::Impl {
eval_null_func_list.push_back(null_func);
}

for (auto memo_index : indices) {
if (memo_index == -1) {
for (auto eval_func : eval_null_func_list) {
RETURN_NOT_OK(eval_func());
}
} else {
for (auto eval_func : eval_func_list) {
RETURN_NOT_OK(eval_func(memo_index));
}
for (auto eval_func : eval_func_list) {
for (auto memo_index : indices) {
RETURN_NOT_OK(eval_func(memo_index));
}
}

Expand Down Expand Up @@ -961,15 +956,9 @@ class HashAggregateKernel::Impl {
eval_null_func_list.push_back(null_func);
}

for (auto memo_index : indices) {
if (memo_index == -1) {
for (auto eval_func : eval_null_func_list) {
RETURN_NOT_OK(eval_func());
}
} else {
for (auto eval_func : eval_func_list) {
RETURN_NOT_OK(eval_func(memo_index));
}
for (auto eval_func : eval_func_list) {
for (auto memo_index : indices) {
RETURN_NOT_OK(eval_func(memo_index));
}
}
return arrow::Status::OK();
Expand Down Expand Up @@ -1115,15 +1104,9 @@ class HashAggregateKernel::Impl {
eval_null_func_list.push_back(null_func);
}

for (auto memo_index : indices) {
if (memo_index == -1) {
for (auto eval_func : eval_null_func_list) {
RETURN_NOT_OK(eval_func());
}
} else {
for (auto eval_func : eval_func_list) {
RETURN_NOT_OK(eval_func(memo_index));
}
for (auto eval_func : eval_func_list) {
for (auto memo_index : indices) {
RETURN_NOT_OK(eval_func(memo_index));
}
}

Expand Down
4 changes: 2 additions & 2 deletions native-sql-engine/cpp/src/precompile/hash_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace precompile {
typename arrow::internal::HashTraits<arrow::TYPENAME>::MemoTableType; \
class HASHMAPNAME::Impl : public MEMOTABLETYPE { \
public: \
Impl(arrow::MemoryPool* pool) : MEMOTABLETYPE(pool) {} \
Impl(arrow::MemoryPool* pool) : MEMOTABLETYPE(pool, 128) {} \
}; \
\
HASHMAPNAME::HASHMAPNAME(arrow::MemoryPool* pool) { \
Expand Down Expand Up @@ -107,6 +107,6 @@ TYPED_ARROW_HASH_MAP_IMPL(StringHashMap, StringType, arrow::util::string_view,
TYPED_ARROW_HASH_MAP_DECIMAL_IMPL(Decimal128HashMap, Decimal128Type, arrow::Decimal128,
DecimalMemoTableType)
#undef TYPED_ARROW_HASH_MAP_IMPL

#undef TYPED_ARROW_HASH_MAP_DECIMAL_IMPL
} // namespace precompile
} // namespace sparkcolumnarplugin
2 changes: 1 addition & 1 deletion native-sql-engine/cpp/src/precompile/unsafe_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class TypedUnsafeArray<DataType, enable_if_string_like<DataType>> : public Unsaf
if (!skip_null_check_ && typed_array_->IsNull(i)) {
setNullAt((*unsafe_row).get(), idx_);
} else {
auto v = typed_array_->GetString(i);
auto v = typed_array_->GetView(i);
appendToUnsafeRow((*unsafe_row).get(), idx_, v);
}
return arrow::Status::OK();
Expand Down
Loading