diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index b00faf4cbbd0e..19e31c8d5c069 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -57,6 +57,10 @@ std::vector GetStringFunctionRegistry() { NativeFunction("upper", DataTypeVector{utf8()}, utf8(), kResultNullIfNull, "upper_utf8", NativeFunction::kNeedsContext), + NativeFunction("castVARCHAR", DataTypeVector{utf8(), int64()}, utf8(), + kResultNullIfNull, "castVARCHAR_utf8_int64", + NativeFunction::kNeedsContext), + NativeFunction("like", DataTypeVector{utf8(), utf8()}, boolean(), kResultNullIfNull, "gdv_fn_like_utf8_utf8", NativeFunction::kNeedsFunctionHolder)}; diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index a0f2c1960556c..5ffaef44efe45 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -160,6 +160,18 @@ char* upper_utf8(int64 context, const char* data, int32 data_len, int32_t* out_l return ret; } +// Truncates the string to given length +FORCE_INLINE +char* castVARCHAR_utf8_int64(int64 context, const char* data, int32 data_len, + int64_t out_len, int32_t* out_length) { + // TODO: handle allocation failures + int32_t len = data_len <= static_cast(out_len) ? data_len : static_cast(out_len); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, len)); + memcpy(ret, data, len); + *out_length = len; + return ret; +} + #define IS_NULL(NAME, TYPE) \ FORCE_INLINE \ bool NAME##_##TYPE(TYPE in, int32 len, boolean is_valid) { return !is_valid; } diff --git a/cpp/src/gandiva/tests/CMakeLists.txt b/cpp/src/gandiva/tests/CMakeLists.txt index 06e29343f5217..df3e7bd9a46c3 100644 --- a/cpp/src/gandiva/tests/CMakeLists.txt +++ b/cpp/src/gandiva/tests/CMakeLists.txt @@ -30,6 +30,7 @@ add_gandiva_test(in_expr_test) add_gandiva_test(null_validity_test) add_gandiva_test(decimal_test) add_gandiva_test(decimal_single_test) +add_gandiva_test(utf8_test) add_gandiva_test(projector_test_static SOURCES projector_test.cc USE_STATIC_LINKING) diff --git a/cpp/src/gandiva/tests/utf8_test.cc b/cpp/src/gandiva/tests/utf8_test.cc index 8129169544c7a..b1ec8da918d31 100644 --- a/cpp/src/gandiva/tests/utf8_test.cc +++ b/cpp/src/gandiva/tests/utf8_test.cc @@ -468,6 +468,55 @@ TEST_F(TestUtf8, TestToDateError) { << status.message(); } +TEST_F(TestUtf8, TestCastVarChar) { + // schema for input fields + auto field_a = field("a", utf8()); + auto field_c = field("c", utf8()); + auto schema = arrow::schema({field_a, field_c}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_c = TreeExprBuilder::MakeField(field_c); + // truncates the string to input length + auto node_b = TreeExprBuilder::MakeLiteral(static_cast(10)); + auto cast_varchar = + TreeExprBuilder::MakeFunction("castVARCHAR", {node_a, node_b}, utf8()); + auto equals = TreeExprBuilder::MakeFunction("equal", {cast_varchar, node_c}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(equals, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array_a = MakeArrowArrayUtf8( + {"park", "Sparkle", "bright spark and fire", "fiery SPARK", "मदन"}, + {true, true, false, true, true}); + + auto array_b = + MakeArrowArrayUtf8({"park", "Sparkle", "bright spar", "fiery SPAR", "मदन"}, + {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + auto exp = MakeArrowArrayBool({true, true, false, true, true}, + {true, true, false, true, true}); + + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]); +} + TEST_F(TestUtf8, TestIsNull) { // schema for input fields auto field_a = field("a", utf8()); @@ -492,11 +541,10 @@ TEST_F(TestUtf8, TestIsNull) { // prepare input record batch auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); - + // Evaluate expression arrow::ArrayVector outputs; status = projector->Evaluate(*in_batch, pool_, &outputs); - // validate results EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, false, true, false}), outputs[0]); // isnull