Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-97] remove isnull when null count is zero #98

Merged
merged 1 commit into from
Feb 14, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 158 additions & 11 deletions cpp/src/codegen/arrow_compute/ext/sort_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,14 @@ class SortArraysToIndicesKernel::Impl {
indice++;
}
std::string cached_insert_str = GetCachedInsert(
shuffle_typed_codegen_list.size(), projected_types_.size(), key_projector_);
shuffle_typed_codegen_list.size(), projected_types_.size(), key_projector_,
key_index_list_);
std::string comp_func_str =
GetCompFunction(key_index_list_, key_projector_, projected_types_,
key_field_list_, sort_directions_, nulls_order_);
std::string comp_func_str_without_null =
GetCompFunctionWithoutNull(key_index_list_, key_projector_, projected_types_,
key_field_list_, sort_directions_);

std::string pre_sort_valid_str = GetPreSortValid();

Expand Down Expand Up @@ -329,6 +333,7 @@ class TypedSorterImpl : public CodeGenBase {
// we should support nulls first and nulls last here
// we should also support desc and asc here
)" + comp_func_str +
comp_func_str_without_null +
R"(
// initiate buffer for all arrays
std::shared_ptr<arrow::Buffer> indices_buf;
Expand Down Expand Up @@ -375,6 +380,7 @@ class TypedSorterImpl : public CodeGenBase {
uint64_t num_batches_ = 0;
uint64_t items_total_ = 0;
uint64_t nulls_total_ = 0;
bool has_null_ = false;

class SortRelationResultIterator : public ResultIterator<SortRelation> {
public:
Expand Down Expand Up @@ -447,7 +453,8 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
)";
}
std::string GetCachedInsert(int shuffle_size, int projected_size,
const std::shared_ptr<gandiva::Projector>& key_projector) {
const std::shared_ptr<gandiva::Projector>& key_projector,
const std::vector<int>& sort_key_index_list) {
std::stringstream ss;
for (int i = 0; i < shuffle_size; i++) {
ss << "cached_" << i << "_.push_back(std::make_shared<ArrayType_" << i << ">(in["
Expand All @@ -459,6 +466,20 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
<< ">(projected_batch[" << i << "]));" << std::endl;
}
}
if (key_projector) {
for (int i = 0; i < projected_size; i++) {
ss << "if (!has_null_ && projected_" << i
<< "_[projected_0_.size() - 1]->null_count() > 0) { " << "has_null_ = true;}"
<< std::endl;
}
} else {
for (int i = 0; i < sort_key_index_list.size(); i++) {
int key_id = sort_key_index_list[i];
ss << "if (!has_null_ && cached_" << key_id << "_[cached_" << key_id
<< "_.size() - 1]->null_count() > 0) {"
<< "has_null_ = true;}" << std::endl;
}
}
return ss.str();
}
std::string GetCompFunction(
Expand All @@ -477,7 +498,26 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
ss << "auto comp = [this](ArrayItemIndexS x, ArrayItemIndexS y) {"
<< GetCompFunction_(0, projected, sort_key_index_list, key_field_list,
projected_types, sort_directions, nulls_order)
<< "};";
<< "};\n";
return ss.str();
}
std::string GetCompFunctionWithoutNull(
const std::vector<int>& sort_key_index_list,
const std::shared_ptr<gandiva::Projector>& key_projector,
const std::vector<std::shared_ptr<arrow::DataType>>& projected_types,
const std::vector<std::shared_ptr<arrow::Field>>& key_field_list,
const std::vector<bool>& sort_directions) {
std::stringstream ss;
bool projected;
if (key_projector) {
projected = true;
} else {
projected = false;
}
ss << "auto comp_without_null = [this](ArrayItemIndexS x, ArrayItemIndexS y) {"
<< GetCompFunction_Without_Null_(0, projected, sort_key_index_list, key_field_list,
projected_types, sort_directions)
<< "};\n";
return ss.str();
}
std::string GetCompFunction_(
Expand Down Expand Up @@ -515,21 +555,27 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
array + std::to_string(cur_key_id) + "_[y.array_id]->GetString(y.id)";
auto is_x_null = array + std::to_string(cur_key_id) + "_[x.array_id]->IsNull(x.id)";
auto is_y_null = array + std::to_string(cur_key_id) + "_[y.array_id]->IsNull(y.id)";
auto x_null_count =
array + std::to_string(cur_key_id) + "_[x.array_id]->null_count() > 0";
auto y_null_count =
array + std::to_string(cur_key_id) + "_[y.array_id]->null_count() > 0";
auto x_null = "(" + x_null_count + " && " + is_x_null + " )";
auto y_null = "(" + y_null_count + " && " + is_y_null + " )";
auto is_x_nan = "std::isnan(" + x_num_value + ")";
auto is_y_nan = "std::isnan(" + y_num_value + ")";

// Multiple keys sorting w/ nulls first/last is supported.
std::stringstream ss;
// We need to determine the position of nulls.
ss << "if (" << is_x_null << ") {\n";
ss << "if (" << x_null << ") {\n";
// If value accessed from x is null, return true to make nulls first.
if (nulls_first) {
ss << "return true;\n}";
} else {
ss << "return false;\n}";
}
// If value accessed from y is null, return false to make nulls first.
ss << " else if (" << is_y_null << ") {\n";
ss << " else if (" << y_null << ") {\n";
if (nulls_first) {
ss << "return false;\n}";
} else {
Expand Down Expand Up @@ -578,17 +624,17 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
// clear the contents of stringstream
ss.str(std::string());
if (data_type->id() == arrow::Type::STRING) {
ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << x_str_value
ss << "if ((" << x_null << " && " << y_null << ") || (" << x_str_value
<< " == " << y_str_value << ")) {";
} else {
if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE ||
data_type->id() == arrow::Type::FLOAT)) {
// need to check NaN
ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << is_x_nan
ss << "if ((" << x_null << " && " << y_null << ") || (" << is_x_nan
<< " && " << is_y_nan << ") || (" << x_num_value << " == " << y_num_value
<< ")) {";
} else {
ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << x_num_value
ss << "if ((" << x_null << " && " << y_null << ") || (" << x_num_value
<< " == " << y_num_value << ")) {";
}
}
Expand All @@ -597,6 +643,104 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
<< "} else { " << comp_str << "}";
return ss.str();
}
std::string GetCompFunction_Without_Null_(
int cur_key_index, bool projected, const std::vector<int>& sort_key_index_list,
const std::vector<std::shared_ptr<arrow::Field>>& key_field_list,
const std::vector<std::shared_ptr<arrow::DataType>>& projected_types,
const std::vector<bool>& sort_directions) {
std::string comp_str;
int cur_key_id;
auto field = key_field_list[cur_key_index];
bool asc = sort_directions[cur_key_index];
std::shared_ptr<arrow::DataType> data_type;
std::string array;
// if projected, use projected batch to compare, and use projected type
if (projected) {
array = "projected_";
data_type = projected_types[cur_key_index];
// use the index of projected key
cur_key_id = cur_key_index;
} else {
array = "cached_";
data_type = field->type();
// use the key_id
cur_key_id = sort_key_index_list[cur_key_index];
}

auto x_num_value =
array + std::to_string(cur_key_id) + "_[x.array_id]->GetView(x.id)";
auto x_str_value =
array + std::to_string(cur_key_id) + "_[x.array_id]->GetString(x.id)";
auto y_num_value =
array + std::to_string(cur_key_id) + "_[y.array_id]->GetView(y.id)";
auto y_str_value =
array + std::to_string(cur_key_id) + "_[y.array_id]->GetString(y.id)";
auto is_x_nan = "std::isnan(" + x_num_value + ")";
auto is_y_nan = "std::isnan(" + y_num_value + ")";

// Multiple keys sorting w/ nulls first/last is supported.
std::stringstream ss;
// If datatype is floating, we need to do partition for NaN if NaN check is enabled
if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE ||
data_type->id() == arrow::Type::FLOAT)) {
ss << "if (" << is_x_nan << ") {\n";
if (asc) {
ss << "return false;\n}";
} else {
ss << "return true;\n}";
}
ss << "else if (" << is_y_nan << ") {\n";
if (asc) {
ss << "return true;\n}";
} else {
ss << "return false;\n}";
}
// If values accessed from x and y are both not nan
ss << " else {\n";
}

// Multiple keys sorting w/ different ordering is supported.
// For string type of data, GetString should be used instead of GetView.
if (asc) {
if (data_type->id() == arrow::Type::STRING) {
ss << "return " << x_str_value << " < " << y_str_value << ";\n";
} else {
ss << "return " << x_num_value << " < " << y_num_value << ";\n";
}
} else {
if (data_type->id() == arrow::Type::STRING) {
ss << "return " << x_str_value << " > " << y_str_value << ";\n";
} else {
ss << "return " << x_num_value << " > " << y_num_value << ";\n";
}
}
if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE ||
data_type->id() == arrow::Type::FLOAT)) {
ss << "}" << std::endl;
}
comp_str = ss.str();
if ((cur_key_index + 1) == sort_key_index_list.size()) {
return comp_str;
}
// clear the contents of stringstream
ss.str(std::string());
if (data_type->id() == arrow::Type::STRING) {
ss << "if (" << x_str_value << " == " << y_str_value << ") {";
} else {
if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE ||
data_type->id() == arrow::Type::FLOAT)) {
// need to check NaN
ss << "if ((" << is_x_nan << " && " << is_y_nan << ") || ("
<< x_num_value << " == " << y_num_value << ")) {";
} else {
ss << "if (" << x_num_value << " == " << y_num_value << ") {";
}
}
ss << GetCompFunction_Without_Null_(cur_key_index + 1, projected, sort_key_index_list,
key_field_list, projected_types, sort_directions)
<< "} else { " << comp_str << "}";
return ss.str();
}
std::string GetPreSortValid() {
if (nulls_first_) {
return R"(
Expand All @@ -620,9 +764,12 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
}
}
std::string GetSortFunction() {
return "gfx::timsort(indices_begin, indices_begin + "
"items_total_, "
"comp);";
std::stringstream ss;
ss << "if (has_null_) {\n"
<< "gfx::timsort(indices_begin, indices_begin + items_total_, comp);} else {\n"
<< "gfx::timsort(indices_begin, indices_begin + items_total_, comp_without_null);}"
<< std::endl;
return ss.str();
}
std::string GetMakeResultIter(int shuffle_size) {
std::stringstream ss;
Expand Down