diff --git a/cpp/src/gandiva/function_registry_math_ops.cc b/cpp/src/gandiva/function_registry_math_ops.cc index f3d80110152d5..e0640a0775c0c 100644 --- a/cpp/src/gandiva/function_registry_math_ops.cc +++ b/cpp/src/gandiva/function_registry_math_ops.cc @@ -108,6 +108,9 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() { "gdv_fn_random", NativeFunction::kNeedsFunctionHolder), NativeFunction("random", {"rand"}, DataTypeVector{int32()}, float64(), kResultNullNever, "gdv_fn_random_with_seed", + NativeFunction::kNeedsFunctionHolder), + NativeFunction("random", {"rand"}, DataTypeVector{int64()}, float64(), + kResultNullNever, "gdv_fn_random_with_seed64", NativeFunction::kNeedsFunctionHolder)}; return math_fn_registry_; diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index a211303b7cbbb..a705acaa64017 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -147,6 +147,12 @@ double gdv_fn_random_with_seed(int64_t ptr, int32_t seed, bool seed_validity) { return (*holder)(); } +double gdv_fn_random_with_seed64(int64_t ptr, int64_t seed, bool seed_validity) { + gandiva::RandomGeneratorHolder* holder = + reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr); + return (*holder)(); +} + int64_t gdv_fn_to_date_utf8_utf8(int64_t context_ptr, int64_t holder_ptr, const char* data, int data_len, bool in1_validity, const char* pattern, int pattern_len, bool in2_validity, @@ -1512,6 +1518,10 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed", types->double_type(), args, reinterpret_cast<void*>(gdv_fn_random_with_seed)); + args = {types->i64_type(), types->i64_type(), types->i1_type()}; + engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed64", types->double_type(), args, + reinterpret_cast<void*>(gdv_fn_random_with_seed64)); + args = {types->i64_type(), // int64_t context_ptr types->i8_ptr_type(), // const char* data types->i32_type()}; // int32_t lenr diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index f56a020c46ce6..5e6de27204271 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -98,6 +98,12 @@ bool in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, bool in_va int gdv_fn_time_with_zone(int* time_fields, const char* zone, int zone_len, int64_t* ret_time); +double gdv_fn_random(int64_t ptr); + +double gdv_fn_random_with_seed(int64_t ptr, int32_t seed, bool seed_validity); + +double gdv_fn_random_with_seed64(int64_t ptr, int64_t seed, bool seed_validity); + GANDIVA_EXPORT const char* gdv_fn_base64_encode_binary(int64_t context, const char* in, int32_t in_len, int32_t* out_len); diff --git a/cpp/src/gandiva/random_generator_holder.cc b/cpp/src/gandiva/random_generator_holder.cc index 3471c87d92b89..1f4f3d020f642 100644 --- a/cpp/src/gandiva/random_generator_holder.cc +++ b/cpp/src/gandiva/random_generator_holder.cc @@ -35,11 +35,16 @@ Status RandomGeneratorHolder::Make(const FunctionNode& node, auto literal_type = literal->return_type()->id(); ARROW_RETURN_IF( - literal_type != arrow::Type::INT32, - Status::Invalid("'random' function requires an int32 literal as parameter")); + literal_type != arrow::Type::INT32 && literal_type != arrow::Type::INT64, + Status::Invalid("'random' function requires an int32/int64 literal as parameter")); - *holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder( + if (literal_type == arrow::Type::INT32) { + *holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder( literal->is_null() ? 0 : arrow::util::get<int32_t>(literal->holder()))); + } else { + *holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder( + literal->is_null() ? 0 : arrow::util::get<int64_t>(literal->holder()))); + } return Status::OK(); } } // namespace gandiva diff --git a/cpp/src/gandiva/random_generator_holder.h b/cpp/src/gandiva/random_generator_holder.h index 65b6607e87840..7cedee0b454f1 100644 --- a/cpp/src/gandiva/random_generator_holder.h +++ b/cpp/src/gandiva/random_generator_holder.h @@ -46,6 +46,11 @@ class GANDIVA_EXPORT RandomGeneratorHolder : public FunctionHolder { generator_.seed(static_cast<uint64_t>(seed64)); } + explicit RandomGeneratorHolder(int64_t seed64) : distribution_(0, 1) { + seed64 = (seed64 ^ 0x00000005DEECE66D) & 0x0000ffffffffffff; + generator_.seed(static_cast<uint64_t>(seed64)); + } + RandomGeneratorHolder() : distribution_(0, 1) { generator_.seed(::arrow::internal::GetRandomSeed()); } diff --git a/cpp/src/gandiva/random_generator_holder_test.cc b/cpp/src/gandiva/random_generator_holder_test.cc index 4b16c1b7d0d9a..2c91187bffa7a 100644 --- a/cpp/src/gandiva/random_generator_holder_test.cc +++ b/cpp/src/gandiva/random_generator_holder_test.cc @@ -85,6 +85,27 @@ TEST_F(TestRandGenHolder, WithValidSeeds) { EXPECT_NE(random_1(), random_2()); } +TEST_F(TestRandGenHolder, WithValidSeedsInLongType) { + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1; + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2; + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_3; + FunctionNode rand_func_1 = BuildRandWithSeedFunc(100L, false); + FunctionNode rand_func_2 = BuildRandWithSeedFunc(1000L, false); + FunctionNode rand_func_3 = BuildRandWithSeedFunc(100000L, false); + auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1); + EXPECT_EQ(status.ok(), true) << status.message(); + status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2); + EXPECT_EQ(status.ok(), true) << status.message(); + status = RandomGeneratorHolder::Make(rand_func_3, &rand_gen_holder_3); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& random_1 = *rand_gen_holder_1; + auto& random_2 = *rand_gen_holder_2; + auto& random_3 = *rand_gen_holder_3; + EXPECT_NE(random_2(), random_3()); + EXPECT_NE(random_1(), random_2()); +} + TEST_F(TestRandGenHolder, WithInValidSeed) { std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1; std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2;