diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc index 309f2fd5a..78cf992ad 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/codegen_node_visitor.cc @@ -298,13 +298,50 @@ arrow::Status CodeGenNodeVisitor::Visit(const gandiva::FunctionNode& node) { RETURN_NOT_OK(AppendProjectList(child_visitor_list, i)); } codes_str_ = ss.str(); - } else if (func_name.find("cast") != std::string::npos) { + } else if (func_name.find("cast") != std::string::npos && + func_name.compare("castDATE") != 0 && + func_name.compare("castDECIMAL") != 0 && + func_name.compare("castDECIMALNullOnOverflow") != 0) { ss << child_visitor_list[0]->GetResult(); check_str_ = child_visitor_list[0]->GetPreCheck(); for (int i = 0; i < 1; i++) { RETURN_NOT_OK(AppendProjectList(child_visitor_list, i)); } codes_str_ = ss.str(); + } else if (func_name.compare("castDECIMAL") == 0) { + codes_str_ = func_name + "_" + std::to_string(cur_func_id); + auto validity = codes_str_ + "_validity"; + std::stringstream fix_ss; + auto decimal_type = + std::dynamic_pointer_cast(node.return_type()); + auto childNode = node.children().at(0); + if (childNode->return_type()->id() != arrow::Type::DECIMAL) { + // if not casting from Decimal + fix_ss << ", " << decimal_type->precision() << ", " << decimal_type->scale(); + } else { + // if casting from Decimal + auto childType = + std::dynamic_pointer_cast(childNode->return_type()); + fix_ss << ", " << childType->precision() << ", " << childType->scale() << ", " + << decimal_type->precision() << ", " << decimal_type->scale(); + } + std::stringstream prepare_ss; + prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" + << std::endl; + prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck() + << ";" << std::endl; + prepare_ss << "if (" << validity << ") {" << std::endl; + prepare_ss << codes_str_ << " = " << func_name << "(" + << child_visitor_list[0]->GetResult() << fix_ss.str() << ");" + << std::endl; + prepare_ss << "}" << std::endl; + + for (int i = 0; i < 1; i++) { + prepare_str_ += child_visitor_list[i]->GetPrepare(); + } + prepare_str_ += prepare_ss.str(); + check_str_ = validity; + } else if (func_name.compare("add") == 0) { codes_str_ = "add_" + std::to_string(cur_func_id); auto validity = "add_validity_" + std::to_string(cur_func_id);