diff --git a/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc b/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc index b8297dc9a..24ed2b691 100644 --- a/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc +++ b/cpp/src/codegen/arrow_compute/ext/sort_kernel.cc @@ -252,10 +252,14 @@ class SortArraysToIndicesKernel::Impl { indice++; } std::string cached_insert_str = GetCachedInsert( - shuffle_typed_codegen_list.size(), projected_types_.size(), key_projector_); + shuffle_typed_codegen_list.size(), projected_types_.size(), key_projector_, + key_index_list_); std::string comp_func_str = GetCompFunction(key_index_list_, key_projector_, projected_types_, key_field_list_, sort_directions_, nulls_order_); + std::string comp_func_str_without_null = + GetCompFunctionWithoutNull(key_index_list_, key_projector_, projected_types_, + key_field_list_, sort_directions_); std::string pre_sort_valid_str = GetPreSortValid(); @@ -329,6 +333,7 @@ class TypedSorterImpl : public CodeGenBase { // we should support nulls first and nulls last here // we should also support desc and asc here )" + comp_func_str + + comp_func_str_without_null + R"( // initiate buffer for all arrays std::shared_ptr indices_buf; @@ -375,6 +380,7 @@ class TypedSorterImpl : public CodeGenBase { uint64_t num_batches_ = 0; uint64_t items_total_ = 0; uint64_t nulls_total_ = 0; + bool has_null_ = false; class SortRelationResultIterator : public ResultIterator { public: @@ -447,7 +453,8 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, )"; } std::string GetCachedInsert(int shuffle_size, int projected_size, - const std::shared_ptr& key_projector) { + const std::shared_ptr& key_projector, + const std::vector& sort_key_index_list) { std::stringstream ss; for (int i = 0; i < shuffle_size; i++) { ss << "cached_" << i << "_.push_back(std::make_shared(in[" @@ -459,6 +466,20 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, << ">(projected_batch[" << i << "]));" << std::endl; } } + if (key_projector) { + for (int i = 0; i < projected_size; i++) { + ss << "if (!has_null_ && projected_" << i + << "_[projected_0_.size() - 1]->null_count() > 0) { " << "has_null_ = true;}" + << std::endl; + } + } else { + for (int i = 0; i < sort_key_index_list.size(); i++) { + int key_id = sort_key_index_list[i]; + ss << "if (!has_null_ && cached_" << key_id << "_[cached_" << key_id + << "_.size() - 1]->null_count() > 0) {" + << "has_null_ = true;}" << std::endl; + } + } return ss.str(); } std::string GetCompFunction( @@ -477,7 +498,26 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, ss << "auto comp = [this](ArrayItemIndexS x, ArrayItemIndexS y) {" << GetCompFunction_(0, projected, sort_key_index_list, key_field_list, projected_types, sort_directions, nulls_order) - << "};"; + << "};\n"; + return ss.str(); + } + std::string GetCompFunctionWithoutNull( + const std::vector& sort_key_index_list, + const std::shared_ptr& key_projector, + const std::vector>& projected_types, + const std::vector>& key_field_list, + const std::vector& sort_directions) { + std::stringstream ss; + bool projected; + if (key_projector) { + projected = true; + } else { + projected = false; + } + ss << "auto comp_without_null = [this](ArrayItemIndexS x, ArrayItemIndexS y) {" + << GetCompFunction_Without_Null_(0, projected, sort_key_index_list, key_field_list, + projected_types, sort_directions) + << "};\n"; return ss.str(); } std::string GetCompFunction_( @@ -515,13 +555,19 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, array + std::to_string(cur_key_id) + "_[y.array_id]->GetString(y.id)"; auto is_x_null = array + std::to_string(cur_key_id) + "_[x.array_id]->IsNull(x.id)"; auto is_y_null = array + std::to_string(cur_key_id) + "_[y.array_id]->IsNull(y.id)"; + auto x_null_count = + array + std::to_string(cur_key_id) + "_[x.array_id]->null_count() > 0"; + auto y_null_count = + array + std::to_string(cur_key_id) + "_[y.array_id]->null_count() > 0"; + auto x_null = "(" + x_null_count + " && " + is_x_null + " )"; + auto y_null = "(" + y_null_count + " && " + is_y_null + " )"; auto is_x_nan = "std::isnan(" + x_num_value + ")"; auto is_y_nan = "std::isnan(" + y_num_value + ")"; // Multiple keys sorting w/ nulls first/last is supported. std::stringstream ss; // We need to determine the position of nulls. - ss << "if (" << is_x_null << ") {\n"; + ss << "if (" << x_null << ") {\n"; // If value accessed from x is null, return true to make nulls first. if (nulls_first) { ss << "return true;\n}"; @@ -529,7 +575,7 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, ss << "return false;\n}"; } // If value accessed from y is null, return false to make nulls first. - ss << " else if (" << is_y_null << ") {\n"; + ss << " else if (" << y_null << ") {\n"; if (nulls_first) { ss << "return false;\n}"; } else { @@ -578,17 +624,17 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, // clear the contents of stringstream ss.str(std::string()); if (data_type->id() == arrow::Type::STRING) { - ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << x_str_value + ss << "if ((" << x_null << " && " << y_null << ") || (" << x_str_value << " == " << y_str_value << ")) {"; } else { if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE || data_type->id() == arrow::Type::FLOAT)) { // need to check NaN - ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << is_x_nan + ss << "if ((" << x_null << " && " << y_null << ") || (" << is_x_nan << " && " << is_y_nan << ") || (" << x_num_value << " == " << y_num_value << ")) {"; } else { - ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << x_num_value + ss << "if ((" << x_null << " && " << y_null << ") || (" << x_num_value << " == " << y_num_value << ")) {"; } } @@ -597,6 +643,104 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, << "} else { " << comp_str << "}"; return ss.str(); } + std::string GetCompFunction_Without_Null_( + int cur_key_index, bool projected, const std::vector& sort_key_index_list, + const std::vector>& key_field_list, + const std::vector>& projected_types, + const std::vector& sort_directions) { + std::string comp_str; + int cur_key_id; + auto field = key_field_list[cur_key_index]; + bool asc = sort_directions[cur_key_index]; + std::shared_ptr data_type; + std::string array; + // if projected, use projected batch to compare, and use projected type + if (projected) { + array = "projected_"; + data_type = projected_types[cur_key_index]; + // use the index of projected key + cur_key_id = cur_key_index; + } else { + array = "cached_"; + data_type = field->type(); + // use the key_id + cur_key_id = sort_key_index_list[cur_key_index]; + } + + auto x_num_value = + array + std::to_string(cur_key_id) + "_[x.array_id]->GetView(x.id)"; + auto x_str_value = + array + std::to_string(cur_key_id) + "_[x.array_id]->GetString(x.id)"; + auto y_num_value = + array + std::to_string(cur_key_id) + "_[y.array_id]->GetView(y.id)"; + auto y_str_value = + array + std::to_string(cur_key_id) + "_[y.array_id]->GetString(y.id)"; + auto is_x_nan = "std::isnan(" + x_num_value + ")"; + auto is_y_nan = "std::isnan(" + y_num_value + ")"; + + // Multiple keys sorting w/ nulls first/last is supported. + std::stringstream ss; + // If datatype is floating, we need to do partition for NaN if NaN check is enabled + if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE || + data_type->id() == arrow::Type::FLOAT)) { + ss << "if (" << is_x_nan << ") {\n"; + if (asc) { + ss << "return false;\n}"; + } else { + ss << "return true;\n}"; + } + ss << "else if (" << is_y_nan << ") {\n"; + if (asc) { + ss << "return true;\n}"; + } else { + ss << "return false;\n}"; + } + // If values accessed from x and y are both not nan + ss << " else {\n"; + } + + // Multiple keys sorting w/ different ordering is supported. + // For string type of data, GetString should be used instead of GetView. + if (asc) { + if (data_type->id() == arrow::Type::STRING) { + ss << "return " << x_str_value << " < " << y_str_value << ";\n"; + } else { + ss << "return " << x_num_value << " < " << y_num_value << ";\n"; + } + } else { + if (data_type->id() == arrow::Type::STRING) { + ss << "return " << x_str_value << " > " << y_str_value << ";\n"; + } else { + ss << "return " << x_num_value << " > " << y_num_value << ";\n"; + } + } + if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE || + data_type->id() == arrow::Type::FLOAT)) { + ss << "}" << std::endl; + } + comp_str = ss.str(); + if ((cur_key_index + 1) == sort_key_index_list.size()) { + return comp_str; + } + // clear the contents of stringstream + ss.str(std::string()); + if (data_type->id() == arrow::Type::STRING) { + ss << "if (" << x_str_value << " == " << y_str_value << ") {"; + } else { + if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE || + data_type->id() == arrow::Type::FLOAT)) { + // need to check NaN + ss << "if ((" << is_x_nan << " && " << is_y_nan << ") || (" + << x_num_value << " == " << y_num_value << ")) {"; + } else { + ss << "if (" << x_num_value << " == " << y_num_value << ") {"; + } + } + ss << GetCompFunction_Without_Null_(cur_key_index + 1, projected, sort_key_index_list, + key_field_list, projected_types, sort_directions) + << "} else { " << comp_str << "}"; + return ss.str(); + } std::string GetPreSortValid() { if (nulls_first_) { return R"( @@ -620,9 +764,12 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx, } } std::string GetSortFunction() { - return "gfx::timsort(indices_begin, indices_begin + " - "items_total_, " - "comp);"; + std::stringstream ss; + ss << "if (has_null_) {\n" + << "gfx::timsort(indices_begin, indices_begin + items_total_, comp);} else {\n" + << "gfx::timsort(indices_begin, indices_begin + items_total_, comp_without_null);}" + << std::endl; + return ss.str(); } std::string GetMakeResultIter(int shuffle_size) { std::stringstream ss;