Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
add nan support in non-codegen sort
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Feb 4, 2021
1 parent d39f680 commit b0d6ad4
Show file tree
Hide file tree
Showing 3 changed files with 446 additions and 9 deletions.
339 changes: 332 additions & 7 deletions cpp/src/codegen/arrow_compute/ext/cmp_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,321 @@ class TypedComparator {
using ArrayType = typename arrow::TypeTraits<DataType>::ArrayType;
};

template <typename DataType, typename CType>
class FloatingComparator {
public:
FloatingComparator() {}

~FloatingComparator() {}

func::function<void(int, int, int64_t, int64_t, int&)> GetCompareFunc(
const arrow::ArrayVector& arrays, bool asc, bool nulls_first, bool nan_check) {
uint64_t null_total = 0;
std::vector<std::shared_ptr<ArrayType>> typed_arrays;
for (int array_id = 0; array_id < arrays.size(); array_id++) {
null_total += arrays[array_id]->null_count();
auto typed_array = std::dynamic_pointer_cast<ArrayType>(arrays[array_id]);
typed_arrays.push_back(typed_array);
}
if (null_total == 0) {
if (asc) {
if (nan_check) {
// null_total == 0, asc, nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
bool is_left_nan = std::isnan(left);
bool is_right_nan = std::isnan(right);
if (!is_left_nan || !is_right_nan) {
if (is_left_nan) {
cmp_res = 0;
} else if (is_right_nan) {
cmp_res = 1;
} else {
if (left != right) {
cmp_res = left < right;
}
}
}
};
} else {
// null_total == 0, asc, !nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
if (left != right) {
cmp_res = left < right;
}
};
}
} else {
if (nan_check) {
// null_total == 0, desc, nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
bool is_left_nan = std::isnan(left);
bool is_right_nan = std::isnan(right);
if (!is_left_nan || !is_right_nan) {
if (is_left_nan) {
cmp_res = 1;
} else if (is_right_nan) {
cmp_res = 0;
} else {
if (left != right) {
cmp_res = left > right;
}
}
}
};
} else {
// null_total == 0, desc, !nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
if (left != right) {
cmp_res = left > right;
}
};
}
}
} else if (asc) {
if (nulls_first) {
if (nan_check) {
// nulls_first, asc, nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 &&
typed_arrays[left_array_id]->IsNull(left_id);
bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 &&
typed_arrays[right_array_id]->IsNull(right_id);
if (!is_left_null || !is_right_null) {
if (is_left_null) {
cmp_res = 1;
} else if (is_right_null) {
cmp_res = 0;
} else {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
bool is_left_nan = std::isnan(left);
bool is_right_nan = std::isnan(right);
if (!is_left_nan || !is_right_nan) {
if (is_left_nan) {
cmp_res = 0;
} else if (is_right_nan) {
cmp_res = 1;
} else {
if (left != right) {
cmp_res = left < right;
}
}
}
}
}
};
} else {
// nulls_first, asc, !nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 &&
typed_arrays[left_array_id]->IsNull(left_id);
bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 &&
typed_arrays[right_array_id]->IsNull(right_id);
if (!is_left_null || !is_right_null) {
if (is_left_null) {
cmp_res = 1;
} else if (is_right_null) {
cmp_res = 0;
} else {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
if (left != right) {
cmp_res = left < right;
}
}
}
};
}
} else {
if (nan_check) {
// nulls_last, asc, nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 &&
typed_arrays[left_array_id]->IsNull(left_id);
bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 &&
typed_arrays[right_array_id]->IsNull(right_id);
if (!is_left_null || !is_right_null) {
if (is_left_null) {
cmp_res = 0;
} else if (is_right_null) {
cmp_res = 1;
} else {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
bool is_left_nan = std::isnan(left);
bool is_right_nan = std::isnan(right);
if (!is_left_nan || !is_right_nan) {
if (is_left_nan) {
cmp_res = 0;
} else if (is_right_nan) {
cmp_res = 1;
} else {
if (left != right) {
cmp_res = left < right;
}
}
}
}
}
};
} else {
// nulls_last, asc, !nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 &&
typed_arrays[left_array_id]->IsNull(left_id);
bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 &&
typed_arrays[right_array_id]->IsNull(right_id);
if (!is_left_null || !is_right_null) {
if (is_left_null) {
cmp_res = 0;
} else if (is_right_null) {
cmp_res = 1;
} else {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
if (left != right) {
cmp_res = left < right;
}
}
}
};
}
}
} else if (nulls_first) {
if (nan_check) {
// nulls_first, desc, nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 &&
typed_arrays[left_array_id]->IsNull(left_id);
bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 &&
typed_arrays[right_array_id]->IsNull(right_id);
if (!is_left_null || !is_right_null) {
if (is_left_null) {
cmp_res = 1;
} else if (is_right_null) {
cmp_res = 0;
} else {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
bool is_left_nan = std::isnan(left);
bool is_right_nan = std::isnan(right);
if (!is_left_nan || !is_right_nan) {
if (is_left_nan) {
cmp_res = 1;
} else if (is_right_nan) {
cmp_res = 0;
} else {
if (left != right) {
cmp_res = left > right;
}
}
}
}
}
};
} else {
// nulls_first, desc, !nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 &&
typed_arrays[left_array_id]->IsNull(left_id);
bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 &&
typed_arrays[right_array_id]->IsNull(right_id);
if (!is_left_null || !is_right_null) {
if (is_left_null) {
cmp_res = 1;
} else if (is_right_null) {
cmp_res = 0;
} else {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
if (left != right) {
cmp_res = left > right;
}
}
}
};
}
} else {
if (nan_check) {
// nulls_last, desc, nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 &&
typed_arrays[left_array_id]->IsNull(left_id);
bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 &&
typed_arrays[right_array_id]->IsNull(right_id);
if (!is_left_null || !is_right_null) {
if (is_left_null) {
cmp_res = 0;
} else if (is_right_null) {
cmp_res = 1;
} else {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
bool is_left_nan = std::isnan(left);
bool is_right_nan = std::isnan(right);
if (!is_left_nan || !is_right_nan) {
if (is_left_nan) {
cmp_res = 1;
} else if (is_right_nan) {
cmp_res = 0;
} else {
if (left != right) {
cmp_res = left > right;
}
}
}
}
}
};
} else {
// nulls_last, desc, !nan_check
return [=](int left_array_id, int right_array_id,
int64_t left_id, int64_t right_id, int& cmp_res) {
bool is_left_null = typed_arrays[left_array_id]->null_count() > 0 &&
typed_arrays[left_array_id]->IsNull(left_id);
bool is_right_null = typed_arrays[right_array_id]->null_count() > 0 &&
typed_arrays[right_array_id]->IsNull(right_id);
if (!is_left_null || !is_right_null) {
if (is_left_null) {
cmp_res = 0;
} else if (is_right_null) {
cmp_res = 1;
} else {
CType left = typed_arrays[left_array_id]->GetView(left_id);
CType right = typed_arrays[right_array_id]->GetView(right_id);
if (left != right) {
cmp_res = left > right;
}
}
}
};
}
}
}

private:
using ArrayType = typename arrow::TypeTraits<DataType>::ArrayType;
};

template <typename DataType, typename CType>
class StringComparator {
public:
Expand Down Expand Up @@ -294,16 +609,15 @@ class StringComparator {
PROCESS(arrow::Int32Type) \
PROCESS(arrow::UInt64Type) \
PROCESS(arrow::Int64Type) \
PROCESS(arrow::FloatType) \
PROCESS(arrow::DoubleType) \
PROCESS(arrow::Date32Type) \
PROCESS(arrow::Date64Type)
static arrow::Status MakeCmpFunction(
const std::vector<arrow::ArrayVector>& array_vectors,
std::vector<std::shared_ptr<arrow::Field>> key_field_list,
std::vector<int> key_index_list,
std::vector<bool> sort_directions,
std::vector<bool> nulls_order,
const std::vector<std::shared_ptr<arrow::Field>>& key_field_list,
const std::vector<int>& key_index_list,
const std::vector<bool>& sort_directions,
const std::vector<bool>& nulls_order,
const bool& nan_check,
std::vector<func::function<void(int, int, int64_t, int64_t, int&)>>& cmp_functions) {
for (int i = 0; i < key_field_list.size(); i++) {
auto type = key_field_list[i]->type();
Expand All @@ -316,6 +630,16 @@ static arrow::Status MakeCmpFunction(
std::make_shared<StringComparator<arrow::StringType, std::string>>();
cmp_functions.push_back(
comparator_ptr->GetCompareFunc(col, asc, nulls_first));
} else if (type->id() == arrow::Type::DOUBLE) {
auto comparator_ptr =
std::make_shared<FloatingComparator<arrow::DoubleType, double>>();
cmp_functions.push_back(
comparator_ptr->GetCompareFunc(col, asc, nulls_first, nan_check));
} else if (type->id() == arrow::Type::FLOAT) {
auto comparator_ptr =
std::make_shared<FloatingComparator<arrow::FloatType, float>>();
cmp_functions.push_back(
comparator_ptr->GetCompareFunc(col, asc, nulls_first, nan_check));
} else {
switch (type->id()) {
#define PROCESS(InType) \
Expand All @@ -327,7 +651,8 @@ static arrow::Status MakeCmpFunction(
PROCESS_SUPPORTED_TYPES(PROCESS)
#undef PROCESS
default: {
std::cout << "MakeCmpFunction type not supported, type is " << type << std::endl;
std::cout << "MakeCmpFunction type not supported, type is "
<< type << std::endl;
} break;
}
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/codegen/arrow_compute/ext/sort_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1721,11 +1721,11 @@ class SortMultiplekeyKernel : public SortArraysToIndicesKernel::Impl {
}
MakeCmpFunction(
projected_, projected_field_list_, projected_key_idx_list, sort_directions_,
nulls_order_, cmp_functions_);
nulls_order_, NaN_check_, cmp_functions_);
} else {
MakeCmpFunction(
cached_, key_field_list_, key_index_list_, sort_directions_,
nulls_order_, cmp_functions_);
nulls_order_, NaN_check_, cmp_functions_);
}
Sort(indices_begin, indices_end);
std::shared_ptr<arrow::FixedSizeBinaryType> out_type;
Expand Down
Loading

0 comments on commit b0d6ad4

Please sign in to comment.