Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-610] hashagg opt#2 #735

Merged
7 commits merged on Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class UniqueAction : public ActionBase {
if (cache_validity_.size() <= max_group_id) {
cache_validity_.resize(max_group_id + 1, false);
null_flag_.resize(max_group_id + 1, false);
cache_.resize(max_group_id + 1);
length_ = cache_validity_.size();
}

Expand All @@ -156,12 +157,10 @@ class UniqueAction : public ActionBase {
if (cache_validity_[dest_group_id] == false) {
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_.emplace(cache_.begin() + dest_group_id, in_->GetView(row_id_));
cache_[dest_group_id] = (CType)in_->GetView(row_id_);
} else {
cache_validity_[dest_group_id] = true;
null_flag_[dest_group_id] = true;
CType num;
cache_.emplace(cache_.begin() + dest_group_id, num);
}
}
row_id_++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -777,9 +777,8 @@ class HashAggregateKernel::Impl {
}
} else {
for (int i = 0; i < length; i++) {
auto aggr_key = typed_key_in->GetView(i);
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
typed_key_in->GetView(i), [](int) {}, [](int) {}, &(indices[i]));
}
}

Expand Down Expand Up @@ -906,38 +905,44 @@ class HashAggregateKernel::Impl {
auto length = in[0]->length();
std::vector<int> indices;
indices.resize(length, -1);
for (int i = 0; i < length; i++) {
auto aggr_key_validity = true;
arrow::util::string_view aggr_key;
if (aggr_key_unsafe_row) {

arrow::util::string_view aggr_key;
if (aggr_key_unsafe_row) {
for (int i = 0; i < length; i++) {
aggr_key_unsafe_row->reset();
int idx = 0;

for (auto payload_arr : payloads) {
payload_arr->Append(i, &aggr_key_unsafe_row);
}
aggr_key = arrow::util::string_view(aggr_key_unsafe_row->data,
aggr_key_unsafe_row->cursor);
} else {
aggr_key = typed_key_in->GetView(i);
aggr_key_validity =
typed_key_in->null_count() == 0 ? true : !typed_key_in->IsNull(i);
}

// 3. get key from hash_table
int memo_index = 0;
if (!aggr_key_validity) {
memo_index = aggr_hash_table_->GetOrInsertNull([](int) {}, [](int) {});
} else {
// FIXME(): all keys are null?
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &memo_index);
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
}
} else {
for (int i = 0; i < length; i++) {
if (typed_key_in->null_count() > 0) {
aggr_key = typed_key_in->GetView(i);
auto aggr_key_validity =
typed_key_in->null_count() == 0 ? true : !typed_key_in->IsNull(i);

if (!aggr_key_validity) {
indices[i] = aggr_hash_table_->GetOrInsertNull([](int) {}, [](int) {});
} else {
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
}
} else {
aggr_key = typed_key_in->GetView(i);

if (memo_index > max_group_id_) {
max_group_id_ = memo_index;
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
}
}
indices[i] = memo_index;
}

max_group_id_ = aggr_hash_table_->Size() - 1;
total_out_length_ = max_group_id_ + 1;
// 4. prepare action func and evaluate
std::vector<std::function<arrow::Status(int)>> eval_func_list;
Expand Down
1 change: 1 addition & 0 deletions native-sql-engine/cpp/src/jni/jni_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <arrow/builder.h>
#include <arrow/pretty_print.h>
Expand Down
1 change: 1 addition & 0 deletions native-sql-engine/cpp/src/precompile/hash_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ namespace precompile {
void (*on_not_found)(int32_t)) { \
return impl_->GetOrInsertNull(on_found, on_not_found); \
} \
int32_t HASHMAPNAME::Size() { return impl_->size(); } \
int32_t HASHMAPNAME::Get(const TYPE& value) { return impl_->Get(value); } \
int32_t HASHMAPNAME::GetNull() { return impl_->GetNull(); }

Expand Down
1 change: 1 addition & 0 deletions native-sql-engine/cpp/src/precompile/hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace precompile {
void (*on_not_found)(int32_t), int32_t* out_memo_index); \
int32_t GetOrInsertNull(void (*on_found)(int32_t), void (*on_not_found)(int32_t)); \
int32_t Get(const TYPE& value); \
int32_t Size(); \
int32_t GetNull(); \
\
private: \
Expand Down
2 changes: 2 additions & 0 deletions native-sql-engine/cpp/src/tests/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
* limitations under the License.
*/

#pragma once

#include <arrow/array.h>
#include <arrow/buffer.h>
#include <arrow/ipc/json_simple.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,10 @@ class SparseHashMap<Scalar, std::enable_if_t<!std::is_floating_point<Scalar>::va
template <typename Func1, typename Func2>
arrow::Status GetOrInsert(const Scalar& value, Func1&& on_found, Func2&& on_not_found,
int32_t* out_memo_index) {
if (dense_map_.find(value) == dense_map_.end()) {
auto index = size_++;
dense_map_[value] = index;
*out_memo_index = index;
on_not_found(index);
} else {
auto index = dense_map_[value];
*out_memo_index = index;
on_found(index);
auto it = dense_map_.emplace(value, size_);
*out_memo_index = it.first->second;
if (it.second) {
size_++;
}
return arrow::Status::OK();
}
Expand Down