Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-153] Fix incorrect queries after enabling Decimal (#154)
Browse files Browse the repository at this point in the history
* Fix hash_relation incorrect cache index return

Signed-off-by: Chendi Xue <[email protected]>

* Some typo fix in HashMap

Signed-off-by: Chendi Xue <[email protected]>

* Fix a type bug in Aggregate Action, cache_validity is not aligned with data

Signed-off-by: Chendi Xue <[email protected]>

* use multiply res type for unscaledValue

Signed-off-by: Chendi Xue <[email protected]>

* Fix hash_relation incorrect cache index return

Signed-off-by: Chendi Xue <[email protected]>

* Some typo fix in HashMap

Signed-off-by: Chendi Xue <[email protected]>

* Fix a type bug in Aggregate Action, cache_validity is not aligned with data

Signed-off-by: Chendi Xue <[email protected]>

* use multiply res type for unscaledValue

Signed-off-by: Chendi Xue <[email protected]>

* support NaN comparison in wscg

* round down when casting decimal to int

Signed-off-by: Chendi Xue <[email protected]>

Co-authored-by: Rui Mo <[email protected]>
  • Loading branch information
xuechendi and rui-mo authored Mar 12, 2021
1 parent 38b8885 commit ddbac40
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ class ColumnarCheckOverflow(child: Expression, original: CheckOverflow)
DecimalType(dataType.precision, dataType.scale)
val resType = CodeGeneration.getResultType(newDataType)
var function = "castDECIMAL"
if(nullOnOverflow) {
if (nullOnOverflow) {
function = "castDECIMALNullOnOverflow"
}
val funcNode =
Expand Down Expand Up @@ -419,9 +419,14 @@ class ColumnarCast(
} else if (dataType == IntegerType) {
val funcNode = child.dataType match {
case d: DecimalType =>
val half_node = TreeBuilder.makeDecimalLiteral("0.5", 2, 1)
val round_down_node = TreeBuilder.makeFunction(
"subtract",
Lists.newArrayList(child_node, half_node),
childType)
val long_node = TreeBuilder.makeFunction(
"castBIGINT",
Lists.newArrayList(child_node),
Lists.newArrayList(round_down_node),
new ArrowType.Int(64, true))
TreeBuilder.makeFunction("castINT", Lists.newArrayList(long_node), resultType)
case other =>
Expand Down Expand Up @@ -491,11 +496,15 @@ class ColumnarUnscaledValue(child: Expression, original: Expression)
}
val childDataType = child.dataType.asInstanceOf[DecimalType]
val m = ConverterUtils.powerOfTen(childDataType.scale)
val multiplyType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY,
childType.asInstanceOf[ArrowType.Decimal],
new ArrowType.Decimal((m._2).toInt, (m._3).toInt, 128))
val increaseScaleNode =
TreeBuilder.makeFunction(
"multiply",
Lists.newArrayList(child_node, TreeBuilder.makeDecimalLiteral(m._1, m._2, m._3)),
childType)
multiplyType)
val funcNode =
TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(increaseScaleNode), resultType)
(funcNode, resultType)
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/codegen/arrow_compute/ext/actions_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ class UniqueAction : public ActionBase {
builder_->Reset();
length = (offset + length) > length_ ? (length_ - offset) : length;
for (uint64_t i = 0; i < length; i++) {
if (cache_validity_[i]) {
if (cache_validity_[offset + i]) {
if (!null_flag_[offset + i]) {
builder_->Append(cache_[offset + i]);
} else {
Expand Down Expand Up @@ -1727,7 +1727,7 @@ class SumCountAction<DataType, CType, ResDataType, ResCType,
}
cache_sum_.resize(max_group_id, 0);
cache_count_.resize(max_group_id, 0);
cache_validity_.resize(max_group_id + 1, false);
cache_validity_.resize(max_group_id, false);
return arrow::Status::OK();
}

Expand Down Expand Up @@ -1799,7 +1799,7 @@ class SumCountAction<DataType, CType, ResDataType, ResCType,
count_builder_->Reset();
length = (offset + length) > length_ ? (length_ - offset) : length;
for (uint64_t i = 0; i < length; i++) {
if (cache_validity_[i]) {
if (cache_validity_[offset + i]) {
RETURN_NOT_OK(sum_builder_->Append(cache_sum_[offset + i]));
RETURN_NOT_OK(count_builder_->Append(cache_count_[offset + i]));
} else {
Expand Down Expand Up @@ -1907,7 +1907,7 @@ class SumCountAction<DataType, CType, ResDataType, ResCType,
}
cache_sum_.resize(max_group_id, 0);
cache_count_.resize(max_group_id, 0);
cache_validity_.resize(max_group_id + 1, false);
cache_validity_.resize(max_group_id, false);
return arrow::Status::OK();
}

Expand Down Expand Up @@ -1972,7 +1972,7 @@ class SumCountAction<DataType, CType, ResDataType, ResCType,
count_builder_->Reset();
length = (offset + length) > length_ ? (length_ - offset) : length;
for (uint64_t i = 0; i < length; i++) {
if (cache_validity_[i]) {
if (cache_validity_[offset + i]) {
RETURN_NOT_OK(sum_builder_->Append(cache_sum_[offset + i]));
RETURN_NOT_OK(count_builder_->Append(cache_count_[offset + i]));
} else {
Expand Down Expand Up @@ -2157,7 +2157,7 @@ class SumCountMergeAction<DataType, CType, ResDataType, ResCType,
count_builder_->Reset();
length = (offset + length) > length_ ? (length_ - offset) : length;
for (uint64_t i = 0; i < length; i++) {
if (cache_validity_[i]) {
if (cache_validity_[offset + i]) {
RETURN_NOT_OK(sum_builder_->Append(cache_sum_[offset + i]));
RETURN_NOT_OK(count_builder_->Append(cache_count_[offset + i]));
} else {
Expand Down Expand Up @@ -2332,7 +2332,7 @@ class SumCountMergeAction<DataType, CType, ResDataType, ResCType,
count_builder_->Reset();
length = (offset + length) > length_ ? (length_ - offset) : length;
for (uint64_t i = 0; i < length; i++) {
if (cache_validity_[i]) {
if (cache_validity_[offset + i]) {
RETURN_NOT_OK(sum_builder_->Append(cache_sum_[offset + i]));
RETURN_NOT_OK(count_builder_->Append(cache_count_[offset + i]));
} else {
Expand Down
27 changes: 17 additions & 10 deletions cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,16 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("less_than_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " < " +
child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "less_than_with_nan(" + child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("greater_than") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " > " +
child_visitor_list[1]->GetResult() + ")";
Expand All @@ -105,15 +106,16 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("greater_than_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " > " +
child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "greater_than_with_nan(" + child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("less_than_or_equal_to") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" <= " + child_visitor_list[1]->GetResult() + ")";
Expand All @@ -125,15 +127,17 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("less_than_or_equal_to_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" <= " + child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "less_than_or_equal_to_with_nan("
+ child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("greater_than_or_equal_to") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" >= " + child_visitor_list[1]->GetResult() + ")";
Expand All @@ -145,15 +149,17 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("greater_than_or_equal_to_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" >= " + child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "greater_than_or_equal_to_with_nan("
+ child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("equal") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" == " + child_visitor_list[1]->GetResult() + ")";
Expand All @@ -165,15 +171,16 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
}
codes_str_ = ss.str();
} else if (func_name.compare("equal_with_nan") == 0) {
real_codes_str_ = "(" + child_visitor_list[0]->GetResult() +
" == " + child_visitor_list[1]->GetResult() + ")";
real_codes_str_ = "equal_with_nan(" + child_visitor_list[0]->GetResult()
+ ", " + child_visitor_list[1]->GetResult() + ")";
real_validity_str_ = CombineValidity(
{child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()});
ss << real_validity_str_ << " && " << real_codes_str_;
for (int i = 0; i < 2; i++) {
prepare_str_ += child_visitor_list[i]->GetPrepare();
}
codes_str_ = ss.str();
header_list_.push_back(R"(#include "precompile/gandiva.h")");
} else if (func_name.compare("not") == 0) {
std::string check_validity;
if (child_visitor_list[0]->GetPreCheck() != "") {
Expand Down
10 changes: 7 additions & 3 deletions cpp/src/codegen/common/hash_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,18 @@ class HashRelation {
if (hash_table_ == nullptr) {
throw std::runtime_error("HashRelation Get failed, hash_table is null.");
}
if (*(CType*)recent_cached_key_ == payload) return 0;
*(CType*)recent_cached_key_ = payload;
if (sizeof(payload) <= 8) {
if (*(CType*)recent_cached_key_ == payload) return recent_cached_key_probe_res_;
*(CType*)recent_cached_key_ = payload;
}
int32_t v = hash32(payload, true);
auto res = safeLookup(hash_table_, payload, v, &arrayid_list_);
if (res == -1) {
arrayid_list_.clear();
recent_cached_key_probe_res_ = -1;
return -1;
}

recent_cached_key_probe_res_ = 0;
return 0;
}

Expand Down Expand Up @@ -419,6 +422,7 @@ class HashRelation {
std::vector<ArrayItemIndex> arrayid_list_;
int key_size_;
char recent_cached_key_[8] = {0};
int recent_cached_key_probe_res_ = -1;

arrow::Status Insert(int32_t v, std::shared_ptr<UnsafeRow> payload, uint32_t array_id,
uint32_t id) {
Expand Down
68 changes: 68 additions & 0 deletions cpp/src/precompile/gandiva.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,71 @@ arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision,
}
return arrow::Decimal128(out);
}

// A plain floating-point comparison involving NaN always yields false, even
// when NaN is compared with itself. To get the same result as Spark, the
// *_with_nan helpers order NaN above every non-NaN value (including +Infinity).
//
// Returns true when `left` is strictly smaller than `right` under that ordering.
// Declared inline because the definition lives in a header; without it, including
// the header from more than one translation unit causes duplicate-symbol errors.
inline bool less_than_with_nan(double left, double right) {
  // A NaN on the left is never smaller than anything, not even another NaN.
  if (std::isnan(left)) {
    return false;
  }
  // Left is a regular number here, so it is always smaller than a NaN right.
  if (std::isnan(right)) {
    return true;
  }
  return left < right;
}

// Returns true when `left` is strictly greater than `right`, ordering NaN above
// every non-NaN value; two NaNs are not strictly ordered, so that case is false.
// Declared inline because the definition lives in a header.
inline bool greater_than_with_nan(double left, double right) {
  // Nothing is strictly greater than NaN, including another NaN.
  if (std::isnan(right)) {
    return false;
  }
  // Right is a regular number here, so a NaN left is always greater.
  if (std::isnan(left)) {
    return true;
  }
  return left > right;
}

// Returns true when `left` is smaller than or equal to `right`, ordering NaN
// above every non-NaN value and treating two NaNs as equal.
// Declared inline because the definition lives in a header.
inline bool less_than_or_equal_to_with_nan(double left, double right) {
  // Everything, including NaN itself, is <= NaN.
  if (std::isnan(right)) {
    return true;
  }
  // Right is a regular number here, so a NaN left is strictly greater.
  if (std::isnan(left)) {
    return false;
  }
  return left <= right;
}

// Returns true when `left` is greater than or equal to `right`, ordering NaN
// above every non-NaN value and treating two NaNs as equal.
// Declared inline because the definition lives in a header.
inline bool greater_than_or_equal_to_with_nan(double left, double right) {
  // NaN is >= everything, including another NaN.
  if (std::isnan(left)) {
    return true;
  }
  // Left is a regular number here, so it can never reach a NaN right.
  if (std::isnan(right)) {
    return false;
  }
  return left >= right;
}

// Returns true when `left` equals `right`, with two NaNs considered equal to
// each other (unlike the built-in ==, where NaN never equals anything).
// Declared inline because the definition lives in a header.
inline bool equal_with_nan(double left, double right) {
  const bool left_is_nan = std::isnan(left);
  const bool right_is_nan = std::isnan(right);
  // As soon as a NaN is involved, the operands are equal only if both are NaN.
  if (left_is_nan || right_is_nan) {
    return left_is_nan && right_is_nan;
  }
  return left == right;
}
25 changes: 25 additions & 0 deletions cpp/src/tests/arrow_compute_test_precompile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,30 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) {
ASSERT_EQ(res, arrow::Decimal128("13780.2495094037"));
}

// Checks the NaN-aware comparison helpers: NaN must order above every non-NaN
// value (including +Infinity) and two NaNs must compare as equal. Each helper is
// exercised with NaN on the left, on the right and on both sides, plus a plain
// numeric case so the non-NaN fall-through is covered as well.
TEST(TestArrowCompute, ArithmeticComparisonTest) {
  const double nan = std::numeric_limits<double>::quiet_NaN();
  const double inf = std::numeric_limits<double>::infinity();
  const double one = 1.0;
  const double two = 2.0;

  // less_than_with_nan
  ASSERT_FALSE(less_than_with_nan(nan, one));
  ASSERT_FALSE(less_than_with_nan(nan, nan));
  ASSERT_TRUE(less_than_with_nan(one, nan));
  ASSERT_TRUE(less_than_with_nan(inf, nan));
  ASSERT_TRUE(less_than_with_nan(one, two));

  // less_than_or_equal_to_with_nan
  ASSERT_FALSE(less_than_or_equal_to_with_nan(nan, one));
  ASSERT_TRUE(less_than_or_equal_to_with_nan(nan, nan));
  ASSERT_TRUE(less_than_or_equal_to_with_nan(one, nan));
  ASSERT_TRUE(less_than_or_equal_to_with_nan(one, one));

  // greater_than_with_nan
  ASSERT_TRUE(greater_than_with_nan(nan, one));
  ASSERT_TRUE(greater_than_with_nan(nan, inf));
  ASSERT_FALSE(greater_than_with_nan(nan, nan));
  ASSERT_FALSE(greater_than_with_nan(one, nan));
  ASSERT_TRUE(greater_than_with_nan(two, one));

  // greater_than_or_equal_to_with_nan
  ASSERT_TRUE(greater_than_or_equal_to_with_nan(nan, one));
  ASSERT_TRUE(greater_than_or_equal_to_with_nan(nan, nan));
  ASSERT_FALSE(greater_than_or_equal_to_with_nan(one, nan));
  ASSERT_TRUE(greater_than_or_equal_to_with_nan(one, one));

  // equal_with_nan
  ASSERT_FALSE(equal_with_nan(nan, one));
  ASSERT_FALSE(equal_with_nan(one, nan));
  ASSERT_TRUE(equal_with_nan(nan, nan));
  ASSERT_TRUE(equal_with_nan(one, one));
}

} // namespace codegen
} // namespace sparkcolumnarplugin
8 changes: 4 additions & 4 deletions cpp/src/third_party/row_wise_memory/hashMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,8 @@ static inline bool append(unsafeHashMap* hashMap, UnsafeRow* keyRow, int hashVal
*
* return should be a flag of succession of the append.
**/
template <typename CType, typename std::enable_if_t<
!std::is_same<CType, arrow::Decimal128>::value>* = nullptr>
template <typename CType,
typename std::enable_if_t<is_number_alike<CType>::value>* = nullptr>
static inline bool append(unsafeHashMap* hashMap, CType keyRow, int hashVal, char* value,
size_t value_size) {
assert(hashMap->keyArray != NULL);
Expand All @@ -775,7 +775,7 @@ static inline bool append(unsafeHashMap* hashMap, CType keyRow, int hashVal, cha
char* base = hashMap->bytesMap;
int klen = 0;
const int vlen = value_size;
const int recordLength = 4 + +klen + vlen + 4;
const int recordLength = 4 + klen + vlen + 4;
char* record = nullptr;

int keySizeInBytes = hashMap->bytesInKeyArray;
Expand Down Expand Up @@ -889,7 +889,7 @@ static inline bool append(unsafeHashMap* hashMap, CType keyRow, int hashVal, cha
char* base = hashMap->bytesMap;
int klen = 0;
const int vlen = value_size;
const int recordLength = 4 + +klen + vlen + 4;
const int recordLength = 4 + klen + vlen + 4;
char* record = nullptr;

int keySizeInBytes = hashMap->bytesInKeyArray;
Expand Down

0 comments on commit ddbac40

Please sign in to comment.