diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala index c35da971f..c46809b42 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala @@ -282,7 +282,7 @@ class ColumnarCheckOverflow(child: Expression, original: CheckOverflow) DecimalType(dataType.precision, dataType.scale) val resType = CodeGeneration.getResultType(newDataType) var function = "castDECIMAL" - if(nullOnOverflow) { + if (nullOnOverflow) { function = "castDECIMALNullOnOverflow" } val funcNode = @@ -419,9 +419,14 @@ class ColumnarCast( } else if (dataType == IntegerType) { val funcNode = child.dataType match { case d: DecimalType => + val half_node = TreeBuilder.makeDecimalLiteral("0.5", 2, 1) + val round_down_node = TreeBuilder.makeFunction( + "subtract", + Lists.newArrayList(child_node, half_node), + childType) val long_node = TreeBuilder.makeFunction( "castBIGINT", - Lists.newArrayList(child_node), + Lists.newArrayList(round_down_node), new ArrowType.Int(64, true)) TreeBuilder.makeFunction("castINT", Lists.newArrayList(long_node), resultType) case other => @@ -491,11 +496,15 @@ class ColumnarUnscaledValue(child: Expression, original: Expression) } val childDataType = child.dataType.asInstanceOf[DecimalType] val m = ConverterUtils.powerOfTen(childDataType.scale) + val multiplyType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.MULTIPLY, + childType.asInstanceOf[ArrowType.Decimal], + new ArrowType.Decimal((m._2).toInt, (m._3).toInt, 128)) val increaseScaleNode = TreeBuilder.makeFunction( "multiply", Lists.newArrayList(child_node, TreeBuilder.makeDecimalLiteral(m._1, m._2, m._3)), - childType) + multiplyType) val funcNode = TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(increaseScaleNode), resultType) (funcNode, resultType) diff --git a/cpp/src/codegen/arrow_compute/ext/actions_impl.cc b/cpp/src/codegen/arrow_compute/ext/actions_impl.cc index 9c42fb5e6..e232f2b84 100644 --- a/cpp/src/codegen/arrow_compute/ext/actions_impl.cc +++ b/cpp/src/codegen/arrow_compute/ext/actions_impl.cc @@ -240,7 +240,7 @@ class UniqueAction : public ActionBase { builder_->Reset(); length = (offset + length) > length_ ? (length_ - offset) : length; for (uint64_t i = 0; i < length; i++) { - if (cache_validity_[i]) { + if (cache_validity_[offset + i]) { if (!null_flag_[offset + i]) { builder_->Append(cache_[offset + i]); } else { @@ -1727,7 +1727,7 @@ class SumCountActionReset(); length = (offset + length) > length_ ? (length_ - offset) : length; for (uint64_t i = 0; i < length; i++) { - if (cache_validity_[i]) { + if (cache_validity_[offset + i]) { RETURN_NOT_OK(sum_builder_->Append(cache_sum_[offset + i])); RETURN_NOT_OK(count_builder_->Append(cache_count_[offset + i])); } else { @@ -1907,7 +1907,7 @@ class SumCountActionReset(); length = (offset + length) > length_ ? (length_ - offset) : length; for (uint64_t i = 0; i < length; i++) { - if (cache_validity_[i]) { + if (cache_validity_[offset + i]) { RETURN_NOT_OK(sum_builder_->Append(cache_sum_[offset + i])); RETURN_NOT_OK(count_builder_->Append(cache_count_[offset + i])); } else { @@ -2157,7 +2157,7 @@ class SumCountMergeActionReset(); length = (offset + length) > length_ ? (length_ - offset) : length; for (uint64_t i = 0; i < length; i++) { - if (cache_validity_[i]) { + if (cache_validity_[offset + i]) { RETURN_NOT_OK(sum_builder_->Append(cache_sum_[offset + i])); RETURN_NOT_OK(count_builder_->Append(cache_count_[offset + i])); } else { @@ -2332,7 +2332,7 @@ class SumCountMergeActionReset(); length = (offset + length) > length_ ? (length_ - offset) : length; for (uint64_t i = 0; i < length; i++) { - if (cache_validity_[i]) { + if (cache_validity_[offset + i]) { RETURN_NOT_OK(sum_builder_->Append(cache_sum_[offset + i])); RETURN_NOT_OK(count_builder_->Append(cache_count_[offset + i])); } else { diff --git a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index c0c7e5b73..4b4bd657b 100644 --- a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -85,8 +85,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("less_than_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " < " + - child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "less_than_with_nan(" + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -94,6 +94,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("greater_than") == 0) { real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " > " + child_visitor_list[1]->GetResult() + ")"; @@ -105,8 +106,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("greater_than_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " > " + - child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "greater_than_with_nan(" + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -114,6 +115,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("less_than_or_equal_to") == 0) { real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " <= " + child_visitor_list[1]->GetResult() + ")"; @@ -125,8 +127,9 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("less_than_or_equal_to_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + - " <= " + child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "less_than_or_equal_to_with_nan(" + + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -134,6 +137,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("greater_than_or_equal_to") == 0) { real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " >= " + child_visitor_list[1]->GetResult() + ")"; @@ -145,8 +149,9 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("greater_than_or_equal_to_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + - " >= " + child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "greater_than_or_equal_to_with_nan(" + + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -154,6 +159,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("equal") == 0) { real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + " == " + child_visitor_list[1]->GetResult() + ")"; @@ -165,8 +171,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } codes_str_ = ss.str(); } else if (func_name.compare("equal_with_nan") == 0) { - real_codes_str_ = "(" + child_visitor_list[0]->GetResult() + - " == " + child_visitor_list[1]->GetResult() + ")"; + real_codes_str_ = "equal_with_nan(" + child_visitor_list[0]->GetResult() + + ", " + child_visitor_list[1]->GetResult() + ")"; real_validity_str_ = CombineValidity( {child_visitor_list[0]->GetPreCheck(), child_visitor_list[1]->GetPreCheck()}); ss << real_validity_str_ << " && " << real_codes_str_; @@ -174,6 +180,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) prepare_str_ += child_visitor_list[i]->GetPrepare(); } codes_str_ = ss.str(); + header_list_.push_back(R"(#include "precompile/gandiva.h")"); } else if (func_name.compare("not") == 0) { std::string check_validity; if (child_visitor_list[0]->GetPreCheck() != "") { diff --git a/cpp/src/codegen/common/hash_relation.h b/cpp/src/codegen/common/hash_relation.h index 03d370e59..6cb5efebe 100644 --- a/cpp/src/codegen/common/hash_relation.h +++ b/cpp/src/codegen/common/hash_relation.h @@ -305,15 +305,18 @@ class HashRelation { if (hash_table_ == nullptr) { throw std::runtime_error("HashRelation Get failed, hash_table is null."); } - if (*(CType*)recent_cached_key_ == payload) return 0; - *(CType*)recent_cached_key_ = payload; + if (sizeof(payload) <= 8) { + if (*(CType*)recent_cached_key_ == payload) return recent_cached_key_probe_res_; + *(CType*)recent_cached_key_ = payload; + } int32_t v = hash32(payload, true); auto res = safeLookup(hash_table_, payload, v, &arrayid_list_); if (res == -1) { arrayid_list_.clear(); + recent_cached_key_probe_res_ = -1; return -1; } - + recent_cached_key_probe_res_ = 0; return 0; } @@ -419,6 +422,7 @@ class HashRelation { std::vector arrayid_list_; int key_size_; char recent_cached_key_[8] = {0}; + int recent_cached_key_probe_res_ = -1; arrow::Status Insert(int32_t v, std::shared_ptr payload, uint32_t array_id, uint32_t id) { diff --git a/cpp/src/precompile/gandiva.h b/cpp/src/precompile/gandiva.h index de8e608c4..df1416138 100644 --- a/cpp/src/precompile/gandiva.h +++ b/cpp/src/precompile/gandiva.h @@ -121,3 +121,71 @@ arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision, } return arrow::Decimal128(out); } + +// A comparison with a NaN always returns false even when comparing with itself. +// To get the same result as spark, we can regard NaN as big as Infinity when +// doing comparison. +bool less_than_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return false; + } else if (left_is_nan) { + return false; + } else if (right_is_nan) { + return true; + } + return left < right; +} + +bool greater_than_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return false; + } else if (left_is_nan) { + return true; + } else if (right_is_nan) { + return false; + } + return left > right; +} + +bool less_than_or_equal_to_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return true; + } else if (left_is_nan) { + return false; + } else if (right_is_nan) { + return true; + } + return left <= right; +} + +bool greater_than_or_equal_to_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return true; + } else if (left_is_nan) { + return true; + } else if (right_is_nan) { + return false; + } + return left >= right; +} + +bool equal_with_nan(double left, double right) { + bool left_is_nan = std::isnan(left); + bool right_is_nan = std::isnan(right); + if (left_is_nan && right_is_nan) { + return true; + } else if (left_is_nan) { + return false; + } else if (right_is_nan) { + return false; + } + return left == right; +} diff --git a/cpp/src/tests/arrow_compute_test_precompile.cc b/cpp/src/tests/arrow_compute_test_precompile.cc index 8edbd0b74..4f3549749 100644 --- a/cpp/src/tests/arrow_compute_test_precompile.cc +++ b/cpp/src/tests/arrow_compute_test_precompile.cc @@ -68,5 +68,30 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) { ASSERT_EQ(res, arrow::Decimal128("13780.2495094037")); } +TEST(TestArrowCompute, ArithmeticComparisonTest) { + double v1 = std::numeric_limits::quiet_NaN(); + double v2 = 1.0; + bool res = less_than_with_nan(v1, v2); + ASSERT_EQ(res, false); + res = less_than_with_nan(v1, v1); + ASSERT_EQ(res, false); + res = less_than_or_equal_to_with_nan(v1, v2); + ASSERT_EQ(res, false); + res = less_than_or_equal_to_with_nan(v1, v1); + ASSERT_EQ(res, true); + res = greater_than_with_nan(v1, v2); + ASSERT_EQ(res, true); + res = greater_than_with_nan(v1, v1); + ASSERT_EQ(res, false); + res = greater_than_or_equal_to_with_nan(v1, v2); + ASSERT_EQ(res, true); + res = greater_than_or_equal_to_with_nan(v1, v1); + ASSERT_EQ(res, true); + res = equal_with_nan(v1, v2); + ASSERT_EQ(res, false); + res = equal_with_nan(v1, v1); + ASSERT_EQ(res, true); +} + } // namespace codegen } // namespace sparkcolumnarplugin diff --git a/cpp/src/third_party/row_wise_memory/hashMap.h b/cpp/src/third_party/row_wise_memory/hashMap.h index 3e03b0441..d221216b1 100755 --- a/cpp/src/third_party/row_wise_memory/hashMap.h +++ b/cpp/src/third_party/row_wise_memory/hashMap.h @@ -759,8 +759,8 @@ static inline bool append(unsafeHashMap* hashMap, UnsafeRow* keyRow, int hashVal * * return should be a flag of succession of the append. **/ -template ::value>* = nullptr> +template ::value>* = nullptr> static inline bool append(unsafeHashMap* hashMap, CType keyRow, int hashVal, char* value, size_t value_size) { assert(hashMap->keyArray != NULL); @@ -775,7 +775,7 @@ static inline bool append(unsafeHashMap* hashMap, CType keyRow, int hashVal, cha char* base = hashMap->bytesMap; int klen = 0; const int vlen = value_size; - const int recordLength = 4 + +klen + vlen + 4; + const int recordLength = 4 + klen + vlen + 4; char* record = nullptr; int keySizeInBytes = hashMap->bytesInKeyArray; @@ -889,7 +889,7 @@ static inline bool append(unsafeHashMap* hashMap, CType keyRow, int hashVal, cha char* base = hashMap->bytesMap; int klen = 0; const int vlen = value_size; - const int recordLength = 4 + +klen + vlen + 4; + const int recordLength = 4 + klen + vlen + 4; char* record = nullptr; int keySizeInBytes = hashMap->bytesInKeyArray;