From dab32360f933945da923ade0f4b35b379bcf7f57 Mon Sep 17 00:00:00 2001 From: philo Date: Thu, 25 Aug 2022 20:32:41 +0800 Subject: [PATCH] Correct the setting for length_ --- .../codegen/arrow_compute/ext/actions_impl.cc | 276 ++++++++++++------ 1 file changed, 183 insertions(+), 93 deletions(-) diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc index d644574b5..9dac96016 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc @@ -150,6 +150,10 @@ class UniqueAction : public ActionBase { if (cache_validity_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } in_ = std::make_shared(in_list[0]); row_id_ = 0; @@ -198,13 +202,12 @@ class UniqueAction : public ActionBase { cache_validity_.resize(max_group_size, false); null_flag_.resize(max_group_size, false); cache_.resize(max_group_size + 1); - length_ = cache_validity_.size(); return arrow::Status::OK(); } arrow::Status Evaluate(int dest_group_id, void* data) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; if (cache_validity_[dest_group_id] == false) { cache_validity_[dest_group_id] = true; @@ -215,7 +218,7 @@ class UniqueAction : public ActionBase { arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; if (cache_validity_[dest_group_id] == false) { cache_validity_[dest_group_id] = true; @@ -308,6 +311,10 @@ class CountAction : public ActionBase { if (cache_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } in_list_ = in_list; row_id = 0; @@ -356,7 +363,6 @@ class CountAction : public ActionBase { max_group_size = target_group_size * 2; } cache_.resize(max_group_size, 0); - length_ = cache_.size(); return arrow::Status::OK(); } @@ -387,7 +393,7 @@ class CountAction : public ActionBase { arrow::Status Evaluate(int dest_group_id, void* data) { auto target_group_size = dest_group_id + 1; - if (cache_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; cache_[dest_group_id] += 1; @@ -396,7 +402,7 @@ class CountAction : public ActionBase { arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -471,6 +477,10 @@ class CountDistinctAction : public ActionBase { if (cache_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } // prepare evaluate lambda *on_valid = [this](int dest_group_id) { @@ -490,7 +500,6 @@ class CountDistinctAction : public ActionBase { max_group_size = target_group_size * 2; } cache_.resize(max_group_size, 0); - length_ = cache_.size(); return arrow::Status::OK(); } @@ -540,7 +549,7 @@ class CountDistinctAction : public ActionBase { arrow::Status Evaluate(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; cache_[dest_group_id] += 1; return arrow::Status::OK(); @@ -548,7 +557,7 @@ class CountDistinctAction : public ActionBase { arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -622,6 +631,10 @@ class CountLiteralAction : public ActionBase { if (cache_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } // prepare evaluate lambda *on_valid = [this](int dest_group_id) { @@ -641,7 +654,6 @@ class CountLiteralAction : public ActionBase { max_group_size = target_group_size * 2; } cache_.resize(max_group_size, 0); - length_ = cache_.size(); return arrow::Status::OK(); } @@ -661,7 +673,7 @@ class CountLiteralAction : public ActionBase { arrow::Status Evaluate(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; cache_[dest_group_id] += 1; return arrow::Status::OK(); @@ -669,7 +681,7 @@ class CountLiteralAction : public ActionBase { arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -746,6 +758,10 @@ class MinAction> if (cache_validity_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } GetFunction(in_list, max_group_id, on_valid, on_null); return arrow::Status::OK(); } @@ -759,7 +775,6 @@ class MinAction> } cache_validity_.resize(max_group_size, false); cache_.resize(max_group_size, 0); - length_ = cache_validity_.size(); return arrow::Status::OK(); } @@ -775,7 +790,7 @@ class MinAction> arrow::Status Evaluate(int dest_group_id, void* data) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; GetMinResultWithGroupBy(dest_group_id, data); return arrow::Status::OK(); @@ -783,7 +798,7 @@ class MinAction> arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -1027,6 +1042,10 @@ class MinAction> if (cache_validity_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } in_ = std::make_shared(in_list[0]); in_null_count_ = in_->null_count(); @@ -1063,7 +1082,6 @@ class MinAction> } cache_validity_.resize(max_group_size, false); cache_.resize(max_group_size, 0); - length_ = cache_validity_.size(); return arrow::Status::OK(); } @@ -1084,7 +1102,7 @@ class MinAction> arrow::Status Evaluate(int dest_group_id, void* data) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; if (!cache_validity_[dest_group_id]) { cache_[dest_group_id] = *(CType*)data; @@ -1098,7 +1116,7 @@ class MinAction> arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -1183,6 +1201,10 @@ class MinAction> if (cache_validity_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } in_ = std::make_shared(in_list[0]); in_null_count_ = in_->null_count(); @@ -1219,7 +1241,6 @@ class MinAction> } cache_validity_.resize(max_group_size, false); cache_.resize(max_group_size, ""); - length_ = cache_validity_.size(); return arrow::Status::OK(); } @@ -1249,7 +1270,7 @@ class MinAction> arrow::Status Evaluate(int dest_group_id, void* data) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; if (!cache_validity_[dest_group_id]) { cache_[dest_group_id] = *(CType*)data; @@ -1263,7 +1284,7 @@ class MinAction> arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -1353,6 +1374,10 @@ class MaxAction> if (cache_validity_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } GetFunction(in_list, max_group_id, on_valid, on_null); return arrow::Status::OK(); } @@ -1366,7 +1391,6 @@ class MaxAction> } cache_validity_.resize(max_group_size, false); cache_.resize(max_group_size, 0); - length_ = cache_validity_.size(); return arrow::Status::OK(); } @@ -1382,7 +1406,7 @@ class MaxAction> arrow::Status Evaluate(int dest_group_id, void* data) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; GetMaxResultWithGroupBy(dest_group_id, data); return arrow::Status::OK(); @@ -1390,7 +1414,7 @@ class MaxAction> arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -1629,6 +1653,10 @@ class MaxAction> if (cache_validity_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } in_ = std::make_shared(in_list[0]); in_null_count_ = in_->null_count(); @@ -1665,7 +1693,6 @@ class MaxAction> } cache_validity_.resize(max_group_size, false); cache_.resize(max_group_size, 0); - length_ = cache_validity_.size(); return arrow::Status::OK(); } @@ -1685,7 +1712,7 @@ class MaxAction> arrow::Status Evaluate(int dest_group_id, void* data) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; if (!cache_validity_[dest_group_id]) { cache_[dest_group_id] = *(CType*)data; @@ -1699,7 +1726,7 @@ class MaxAction> arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -1784,6 +1811,10 @@ class MaxAction> if (cache_validity_.size() <= max_group_id) { GrowByFactor(max_group_id + 1); } + // The actual size will be used to keep result for each group. + if (length_ < max_group_id + 1) { + length_ = max_group_id + 1; + } in_ = std::make_shared(in_list[0]); in_null_count_ = in_->null_count(); @@ -1821,7 +1852,6 @@ class MaxAction> } cache_validity_.resize(max_group_size, false); cache_.resize(max_group_size, ""); - length_ = cache_validity_.size(); return arrow::Status::OK(); } @@ -1851,7 +1881,7 @@ class MaxAction> arrow::Status Evaluate(int dest_group_id, void* data) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; if (!cache_validity_[dest_group_id]) { cache_[dest_group_id] = *(CType*)data; @@ -1865,7 +1895,7 @@ class MaxAction> arrow::Status EvaluateNull(int dest_group_id) { auto target_group_size = dest_group_id + 1; - if (cache_validity_.size() <= target_group_size) GrowByFactor(target_group_size); + if (cache_validity_.size() < target_group_size) GrowByFactor(target_group_size); if (length_ < target_group_size) length_ = target_group_size; return arrow::Status::OK(); } @@ -1956,6 +1986,10 @@ class SumActionnull_count(); @@ -1998,7 +2032,6 @@ class SumAction(in_list[0]); in_null_count_ = in_->null_count(); @@ -2156,7 +2193,6 @@ class SumActionnull_count(); @@ -2325,7 +2365,6 @@ class SumActionPartial(in_list[0]); in_null_count_ = in_->null_count(); @@ -2489,7 +2532,6 @@ class SumActionPartialnull_count(); @@ -3015,7 +3067,6 @@ class SumCountAction(in_list[0]); in_null_count_ = in_->null_count(); @@ -3202,7 +3257,6 @@ class SumCountAction(in_list[0]); in_count_ = std::make_shared(in_list[1]); @@ -3554,7 +3615,6 @@ class SumCountMergeAction(in_list[0]); in_count_ = std::make_shared(in_list[1]); @@ -3899,7 +3966,6 @@ class AvgByCountAction(in_list[0]); in_null_count_ = in_->null_count(); @@ -4100,7 +4170,6 @@ class StddevSampPartialAction(in_list[0]); in_null_count_ = in_->null_count(); @@ -4331,7 +4404,6 @@ class StddevSampPartialAction(in_list[0]); in_avg_ = std::make_shared(in_list[1]); in_m2_ = std::make_shared(in_list[2]); @@ -4556,7 +4633,6 @@ class StddevSampFinalAction(in_list[0]); in_avg_ = std::make_shared(in_list[1]); @@ -4766,7 +4846,6 @@ class StddevSampFinalActionnull_count(); @@ -5000,7 +5083,6 @@ class FirstPartialAction(in_list[0]); in_null_count_ = in_->null_count(); @@ -5236,7 +5322,6 @@ class FirstPartialAction(in_list[0]); auto value_set_array = in_list[1]; @@ -5730,7 +5821,6 @@ class FirstFinalAction