diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index d025a9c573d97..8a3d284e01350 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -71,18 +71,39 @@ std::vector GetStringFunctionRegistry() { NativeFunction("castINT", {}, DataTypeVector{utf8()}, int32(), kResultNullIfNull, "gdv_fn_castINT_utf8", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + // return null if fail to cast + NativeFunction("castINTOrNull", {}, DataTypeVector{utf8()}, int32(), kResultNullInternal, + "gdv_fn_castINT_or_null_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("castBIGINT", {}, DataTypeVector{utf8()}, int64(), kResultNullIfNull, "gdv_fn_castBIGINT_utf8", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + // return null if fail to cast + NativeFunction("castBIGINTOrNull", {}, DataTypeVector{utf8()}, int64(), kResultNullInternal, + "gdv_fn_castBIGINT_or_null_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("castFLOAT4", {}, DataTypeVector{utf8()}, float32(), kResultNullIfNull, "gdv_fn_castFLOAT4_utf8", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + // return null if fail to cast + NativeFunction("castFLOAT4OrNull", {}, DataTypeVector{utf8()}, float32(), + kResultNullInternal, "gdv_fn_castFLOAT4_or_null_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("castFLOAT8", {}, DataTypeVector{utf8()}, float64(), kResultNullIfNull, "gdv_fn_castFLOAT8_utf8", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + // return null if fail to cast + NativeFunction("castFLOAT8OrNull", {}, DataTypeVector{utf8()}, float64(), + kResultNullInternal, "gdv_fn_castFLOAT8_or_null_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("castVARCHAR", {}, DataTypeVector{int8(), int64()}, utf8(), kResultNullIfNull, "castVARCHAR_int8_int64", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 6756af5e76d9a..159c7c1cc939f 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -308,6 +308,40 @@ CAST_NUMERIC_FROM_STRING(double, arrow::DoubleType, FLOAT8) #undef CAST_NUMERIC_FROM_STRING +#define CAST_NUMERIC_OR_NULL_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \ + GANDIVA_EXPORT \ + OUT_TYPE gdv_fn_cast##TYPE_NAME##_or_null_utf8(int64_t context, const char* data, \ + int32_t len, bool in_valid, bool* out_valid) { \ + OUT_TYPE val = 0; \ + *out_valid = true; \ + if (!in_valid) { \ + *out_valid = false; \ + return val; \ + } \ + /* trim leading and trailing spaces */ \ + int32_t trimmed_len; \ + int32_t start = 0, end = len - 1; \ + while (start <= end && data[start] == ' ') { \ + ++start; \ + } \ + while (end >= start && data[end] == ' ') { \ + --end; \ + } \ + trimmed_len = end - start + 1; \ + const char* trimmed_data = data + start; \ + if (!arrow::internal::ParseValue(trimmed_data, trimmed_len, &val)) { \ + *out_valid = false; \ + } \ + return val; \ + } + +CAST_NUMERIC_OR_NULL_FROM_STRING(int32_t, arrow::Int32Type, INT) +CAST_NUMERIC_OR_NULL_FROM_STRING(int64_t, arrow::Int64Type, BIGINT) +CAST_NUMERIC_OR_NULL_FROM_STRING(float, arrow::FloatType, FLOAT4) +CAST_NUMERIC_OR_NULL_FROM_STRING(double, arrow::DoubleType, FLOAT8) + +#undef CAST_NUMERIC_OR_NULL_FROM_STRING + #define GDV_FN_CAST_VARCHAR_INTEGER(IN_TYPE, ARROW_TYPE) \ GANDIVA_EXPORT \ const char* gdv_fn_castVARCHAR_##IN_TYPE##_int64(int64_t context, gdv_##IN_TYPE value, \ @@ -534,12 +568,30 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_castINT_utf8", types->i32_type(), args, reinterpret_cast(gdv_fn_castINT_utf8)); + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int32_t lenr + types->i1_type(), // bool in2_validity + types->ptr_type(types->i8_type())}; // bool* out_valid + + engine->AddGlobalMappingForFunc("gdv_fn_castINT_or_null_utf8", types->i32_type(), args, + reinterpret_cast(gdv_fn_castINT_or_null_utf8)); + args = {types->i64_type(), // int64_t context_ptr types->i8_ptr_type(), // const char* data types->i32_type()}; // int32_t lenr engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_utf8", types->i64_type(), args, reinterpret_cast(gdv_fn_castBIGINT_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int32_t lenr + types->i1_type(), // bool in2_validity + types->ptr_type(types->i8_type())}; // bool* out_valid + + engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_or_null_utf8", types->i64_type(), args, + reinterpret_cast(gdv_fn_castBIGINT_or_null_utf8)); args = {types->i64_type(), // int64_t context_ptr types->i8_ptr_type(), // const char* data @@ -548,12 +600,30 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_utf8", types->float_type(), args, reinterpret_cast(gdv_fn_castFLOAT4_utf8)); + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int32_t lenr + types->i1_type(), // bool in2_validity + types->ptr_type(types->i8_type())}; // bool* out_valid + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_or_null_utf8", types->float_type(), args, + reinterpret_cast(gdv_fn_castFLOAT4_or_null_utf8)); + args = {types->i64_type(), // int64_t context_ptr types->i8_ptr_type(), // const char* data types->i32_type()}; // int32_t lenr engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_utf8", types->double_type(), args, reinterpret_cast(gdv_fn_castFLOAT8_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int32_t lenr + types->i1_type(), // bool in2_validity + types->ptr_type(types->i8_type())}; // bool* out_valid + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_or_null_utf8", types->double_type(), args, + reinterpret_cast(gdv_fn_castFLOAT8_or_null_utf8)); // gdv_fn_castVARCHAR_int32_int64 args = {types->i64_type(), // int64_t execution_context diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 0fd124d75933d..b3ffd2b61d138 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -88,15 +88,27 @@ char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, GANDIVA_EXPORT int32_t gdv_fn_castINT_utf8(int64_t context, const char* data, int32_t data_len); +GANDIVA_EXPORT +int32_t gdv_fn_castINT_or_null_utf8(int64_t context, const char* data, int32_t data_len, bool in_valid, bool* out_valid); + GANDIVA_EXPORT int64_t gdv_fn_castBIGINT_utf8(int64_t context, const char* data, int32_t data_len); +GANDIVA_EXPORT +int64_t gdv_fn_castBIGINT_or_null_utf8(int64_t context, const char* data, int32_t data_len, bool in_valid, bool* out_valid); + GANDIVA_EXPORT float gdv_fn_castFLOAT4_utf8(int64_t context, const char* data, int32_t data_len); +GANDIVA_EXPORT +float gdv_fn_castFLOAT4_or_null_utf8(int64_t context, const char* data, int32_t data_len, bool in_valid, bool* out_valid); + GANDIVA_EXPORT double gdv_fn_castFLOAT8_utf8(int64_t context, const char* data, int32_t data_len); +GANDIVA_EXPORT +double gdv_fn_castFLOAT8_or_null_utf8(int64_t context, const char* data, int32_t data_len, bool in_valid, bool* out_valid); + GANDIVA_EXPORT const char* gdv_fn_castVARCHAR_int32_int64(int64_t context, int32_t value, int64_t len, int32_t* out_len);