Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-610] hashagg opt#3 #903

Merged
merged 9 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion native-sql-engine/cpp/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,6 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
precompile/builder.cc
precompile/array.cc
precompile/type.cc
precompile/sort.cc
precompile/hash_arrays_kernel.cc
precompile/unsafe_array.cc
precompile/gandiva_projector.cc
Expand Down
183 changes: 128 additions & 55 deletions native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,20 +152,31 @@ class UniqueAction : public ActionBase {
row_id_ = 0;
in_null_count_ = in_->null_count();
// prepare evaluate lambda
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id_);
if (cache_validity_[dest_group_id] == false) {
if (!is_null) {
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id_);
if (cache_validity_[dest_group_id] == false) {
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] = (CType)in_->GetView(row_id_);
} else {
cache_validity_[dest_group_id] = true;
null_flag_[dest_group_id] = true;
}
}
row_id_++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
if (cache_validity_[dest_group_id] == false) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] = (CType)in_->GetView(row_id_);
} else {
cache_validity_[dest_group_id] = true;
null_flag_[dest_group_id] = true;
}
}
row_id_++;
return arrow::Status::OK();
};
row_id_++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id_++;
Expand Down Expand Up @@ -1802,15 +1813,25 @@ class SumAction<DataType, CType, ResDataType, ResCType,
// prepare evaluate lambda
data_ = const_cast<CType*>(in_->data()->GetValues<CType>(1));
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
if (!in_null_count_) {
*on_valid = [this](int dest_group_id) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += data_[row_id];
}
row_id++;
return arrow::Status::OK();
};

row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += data_[row_id];
}
row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -1952,15 +1973,24 @@ class SumAction<DataType, CType, ResDataType, ResCType,
in_null_count_ = in_->null_count();
// prepare evaluate lambda
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += in_->GetView(row_id);
}
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += in_->GetView(row_id);
}
row_id++;
return arrow::Status::OK();
};
row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -2108,18 +2138,29 @@ class SumActionPartial<DataType, CType, ResDataType, ResCType,

in_ = in_list[0];
in_null_count_ = in_->null_count();
// prepare evaluate lambda

data_ = const_cast<CType*>(in_->data()->GetValues<CType>(1));
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
// prepare evaluate lambda
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += data_[row_id];
}
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += data_[row_id];
}
row_id++;
return arrow::Status::OK();
};

row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -2263,17 +2304,28 @@ class SumActionPartial<DataType, CType, ResDataType, ResCType,

in_ = std::make_shared<ArrayType>(in_list[0]);
in_null_count_ = in_->null_count();
// prepare evaluate lambda

row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
// prepare evaluate lambda
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += in_->GetView(row_id);
}
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += in_->GetView(row_id);
}
row_id++;
return arrow::Status::OK();
};

row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -2785,16 +2837,26 @@ class SumCountAction<DataType, CType, ResDataType, ResCType,
// prepare evaluate lambda
data_ = const_cast<CType*>(in_->data()->GetValues<CType>(1));
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
if (!in_null_count_) {
*on_valid = [this](int dest_group_id) {
cache_sum_[dest_group_id] += data_[row_id];
cache_count_[dest_group_id] += 1;
cache_validity_[dest_group_id] = true;
}
row_id++;
return arrow::Status::OK();
};
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_sum_[dest_group_id] += data_[row_id];
cache_count_[dest_group_id] += 1;
cache_validity_[dest_group_id] = true;
}
row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -2963,16 +3025,27 @@ class SumCountAction<DataType, CType, ResDataType, ResCType,
in_null_count_ = in_->null_count();
// prepare evaluate lambda
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_sum_[dest_group_id] += in_->GetView(row_id);
cache_count_[dest_group_id] += 1;
cache_validity_[dest_group_id] = true;
}
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
cache_sum_[dest_group_id] += in_->GetView(row_id);
cache_count_[dest_group_id] += 1;
cache_validity_[dest_group_id] = true;
}
row_id++;
return arrow::Status::OK();
};

row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -975,29 +975,29 @@ class HashAggregateKernel::Impl {
for (int i = 0; i < length; i++) {
aggr_key_unsafe_row->reset();

for (auto payload_arr : payloads) {
for (const auto& payload_arr : payloads) {
payload_arr->Append(i, &aggr_key_unsafe_row);
}
aggr_key = arrow::util::string_view(aggr_key_unsafe_row->data,
aggr_key_unsafe_row->cursor);

// FIXME(): all keys are null?
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
aggr_key_unsafe_row->data, aggr_key_unsafe_row->cursor, [](int) {},
[](int) {}, &(indices[i]));
}
} else {
for (int i = 0; i < length; i++) {
if (typed_key_in->null_count() > 0) {
if (typed_key_in->null_count() > 0) {
for (int i = 0; i < length; i++) {
aggr_key = typed_key_in->GetView(i);
auto aggr_key_validity =
typed_key_in->null_count() == 0 ? true : !typed_key_in->IsNull(i);

if (!aggr_key_validity) {
if (typed_key_in->IsNull(i)) {
indices[i] = aggr_hash_table_->GetOrInsertNull([](int) {}, [](int) {});
} else {
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
}
} else {
}
} else {
for (int i = 0; i < length; i++) {
aggr_key = typed_key_in->GetView(i);

aggr_hash_table_->GetOrInsert(
Expand Down Expand Up @@ -1037,7 +1037,6 @@ class HashAggregateKernel::Impl {

arrow::Status Next(std::shared_ptr<arrow::RecordBatch>* out) {
uint64_t out_length = 0;
int gp_idx = 0;
std::vector<std::shared_ptr<arrow::Array>> outputs;
for (auto action : action_impl_list_) {
action->Finish(offset_, batch_size_, &outputs);
Expand Down
54 changes: 31 additions & 23 deletions native-sql-engine/cpp/src/precompile/hash_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,6 @@
namespace sparkcolumnarplugin {
namespace precompile {

// TYPED_SPARSE_HASH_MAP_IMPL(TYPENAME, TYPE)
//
// Generates the out-of-line implementation of a hash-map wrapper class named
// TYPENAME whose keys are of type TYPE:
//   * TYPENAME::Impl       - publicly derives from SparseHashMap<TYPE>; its
//                            constructor forwards the memory pool to the base.
//   * TYPENAME(pool)       - allocates the Impl via std::make_shared and
//                            stores it in impl_.
//   * GetOrInsert / GetOrInsertNull / Get / GetNull
//                          - thin forwarders to the same-named Impl members;
//                            arguments and return values pass through as-is.
//
// The on_found / on_not_found parameters are plain function pointers taking an
// int32_t; what value they receive is decided by SparseHashMap, which is not
// defined here.
//
// NOTE(review): in the visible part of this file the macro is #undef'd right
// after its definition with no expansion in between, so it appears to generate
// no code - confirm before relying on it.
#define TYPED_SPARSE_HASH_MAP_IMPL(TYPENAME, TYPE) \
class TYPENAME::Impl : public SparseHashMap<TYPE> { \
public: \
Impl(arrow::MemoryPool* pool) : SparseHashMap<TYPE>(pool) {} \
}; \
\
TYPENAME::TYPENAME(arrow::MemoryPool* pool) { impl_ = std::make_shared<Impl>(pool); } \
arrow::Status TYPENAME::GetOrInsert(const TYPE& value, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), \
int32_t* out_memo_index) { \
return impl_->GetOrInsert(value, on_found, on_not_found, out_memo_index); \
} \
int32_t TYPENAME::GetOrInsertNull(void (*on_found)(int32_t), \
void (*on_not_found)(int32_t)) { \
return impl_->GetOrInsertNull(on_found, on_not_found); \
} \
int32_t TYPENAME::Get(const TYPE& value) { return impl_->Get(value); } \
int32_t TYPENAME::GetNull() { return impl_->GetNull(); }

#undef TYPED_SPARSE_HASH_MAP_IMPL

#define TYPED_ARROW_HASH_MAP_IMPL(HASHMAPNAME, TYPENAME, TYPE, MEMOTABLETYPE) \
using MEMOTABLETYPE = \
typename arrow::internal::HashTraits<arrow::TYPENAME>::MemoTableType; \
Expand All @@ -72,6 +51,35 @@ namespace precompile {
int32_t HASHMAPNAME::Get(const TYPE& value) { return impl_->Get(value); } \
int32_t HASHMAPNAME::GetNull() { return impl_->GetNull(); }

// TYPED_ARROW_HASH_MAP_BINARY_IMPL(HASHMAPNAME, TYPENAME, TYPE, MEMOTABLETYPE)
//
// Generates the out-of-line implementation of the hash-map wrapper class
// HASHMAPNAME on top of Arrow's memo table for arrow::TYPENAME:
//   * MEMOTABLETYPE        - introduced as an alias for
//                            arrow::internal::HashTraits<arrow::TYPENAME>::MemoTableType.
//   * HASHMAPNAME::Impl    - publicly derives from MEMOTABLETYPE; its
//                            constructor passes (pool, 128) to the base.
//                            NOTE(review): 128 is presumably the initial entry
//                            capacity of the memo table - confirm against the
//                            Arrow version in use.
//   * HASHMAPNAME(pool)    - allocates the Impl via std::make_shared and
//                            stores it in impl_.
//   * GetOrInsert(const TYPE&, ...)
//                          - forwards a TYPE key to Impl::GetOrInsert.
//   * GetOrInsert(const void*, int len, ...)
//                          - forwards a raw (pointer, length) key to
//                            Impl::GetOrInsert, so callers that already hold a
//                            buffer and its size need not build a TYPE value.
//   * GetOrInsertNull / Get / GetNull
//                          - thin forwarders to the same-named Impl members.
//   * Size()               - returns impl_->size(), converted to int32_t.
//
// Both GetOrInsert overloads write the resulting index through out_memo_index
// and return the arrow::Status produced by the Impl call. The on_found /
// on_not_found parameters are plain function pointers taking an int32_t.
//
// The matching member declarations must exist in the class generated by the
// header; in this file the macro is expanded for StringHashMap with
// TYPE = arrow::util::string_view.
#define TYPED_ARROW_HASH_MAP_BINARY_IMPL(HASHMAPNAME, TYPENAME, TYPE, MEMOTABLETYPE) \
using MEMOTABLETYPE = \
typename arrow::internal::HashTraits<arrow::TYPENAME>::MemoTableType; \
class HASHMAPNAME::Impl : public MEMOTABLETYPE { \
public: \
Impl(arrow::MemoryPool* pool) : MEMOTABLETYPE(pool, 128) {} \
}; \
\
HASHMAPNAME::HASHMAPNAME(arrow::MemoryPool* pool) { \
impl_ = std::make_shared<Impl>(pool); \
} \
arrow::Status HASHMAPNAME::GetOrInsert(const TYPE& value, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), \
int32_t* out_memo_index) { \
return impl_->GetOrInsert(value, on_found, on_not_found, out_memo_index); \
} \
arrow::Status HASHMAPNAME::GetOrInsert( \
const void* value, int len, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), int32_t* out_memo_index) { \
return impl_->GetOrInsert(value, len, on_found, on_not_found, out_memo_index); \
} \
int32_t HASHMAPNAME::GetOrInsertNull(void (*on_found)(int32_t), \
void (*on_not_found)(int32_t)) { \
return impl_->GetOrInsertNull(on_found, on_not_found); \
} \
int32_t HASHMAPNAME::Size() { return impl_->size(); } \
int32_t HASHMAPNAME::Get(const TYPE& value) { return impl_->Get(value); } \
int32_t HASHMAPNAME::GetNull() { return impl_->GetNull(); }

#define TYPED_ARROW_HASH_MAP_DECIMAL_IMPL(HASHMAPNAME, TYPENAME, TYPE, MEMOTABLETYPE) \
using MEMOTABLETYPE = \
typename arrow::internal::HashTraits<arrow::TYPENAME>::MemoTableType; \
Expand Down Expand Up @@ -103,8 +111,8 @@ TYPED_ARROW_HASH_MAP_IMPL(FloatHashMap, FloatType, float, FloatMemoTableType)
TYPED_ARROW_HASH_MAP_IMPL(DoubleHashMap, DoubleType, double, DoubleMemoTableType)
TYPED_ARROW_HASH_MAP_IMPL(Date32HashMap, Date32Type, int32_t, Date32MemoTableType)
TYPED_ARROW_HASH_MAP_IMPL(Date64HashMap, Date64Type, int64_t, Date64MemoTableType)
TYPED_ARROW_HASH_MAP_IMPL(StringHashMap, StringType, arrow::util::string_view,
StringMemoTableType)
TYPED_ARROW_HASH_MAP_BINARY_IMPL(StringHashMap, StringType, arrow::util::string_view,
StringMemoTableType)
TYPED_ARROW_HASH_MAP_DECIMAL_IMPL(Decimal128HashMap, Decimal128Type, arrow::Decimal128,
DecimalMemoTableType)
#undef TYPED_ARROW_HASH_MAP_IMPL
Expand Down
2 changes: 2 additions & 0 deletions native-sql-engine/cpp/src/precompile/hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace precompile {
TYPENAME(arrow::MemoryPool* pool); \
arrow::Status GetOrInsert(const TYPE& value, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), int32_t* out_memo_index); \
arrow::Status GetOrInsert(const void* value, int len, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), int32_t* out_memo_index); \
int32_t GetOrInsertNull(void (*on_found)(int32_t), void (*on_not_found)(int32_t)); \
int32_t Get(const TYPE& value); \
int32_t Size(); \
Expand Down
Loading