Skip to content

Commit

Permalink
Move Nanvl and random functions to datafusion-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 committed Apr 9, 2024
1 parent ff2d202 commit 9bc03ed
Show file tree
Hide file tree
Showing 19 changed files with 383 additions and 242 deletions.
2 changes: 1 addition & 1 deletion datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility,
expr, table_scan, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan,
LogicalPlanBuilder, ScalarUDF, Volatility,
};
use datafusion_functions::math;
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use std::sync::Arc;
Expand Down Expand Up @@ -383,17 +384,17 @@ fn test_const_evaluator_scalar_functions() {

// volatile / stable functions should not be evaluated
// rand() + (1 + 2) --> rand() + 3
let fun = BuiltinScalarFunction::Random;
assert_eq!(fun.volatility(), Volatility::Volatile);
let rand = Expr::ScalarFunction(ScalarFunction::new(fun, vec![]));
let fun = math::random();
assert_eq!(fun.signature().volatility, Volatility::Volatile);
let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
let expr = rand.clone() + (lit(1) + lit(2));
let expected = rand + lit(3);
test_evaluate(expr, expected);

// parenthesization matters: can't rewrite
// (rand() + 1) + 2 --> (rand() + 1) + 2)
let fun = BuiltinScalarFunction::Random;
let rand = Expr::ScalarFunction(ScalarFunction::new(fun, vec![]));
let fun = math::random();
let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
let expr = (rand + lit(1)) + lit(2);
test_evaluate(expr.clone(), expr);
}
Expand Down
21 changes: 0 additions & 21 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ pub enum BuiltinScalarFunction {
Exp,
/// factorial
Factorial,
/// nanvl
Nanvl,
// string functions
/// concat
Concat,
Expand All @@ -56,8 +54,6 @@ pub enum BuiltinScalarFunction {
EndsWith,
/// initcap
InitCap,
/// random
Random,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -114,14 +110,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
BuiltinScalarFunction::InitCap => Volatility::Immutable,

// Volatile builtin functions
BuiltinScalarFunction::Random => Volatility::Volatile,
}
}

Expand Down Expand Up @@ -152,16 +144,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::InitCap => {
utf8_to_str_type(&input_expr_types[0], "initcap")
}
BuiltinScalarFunction::Random => Ok(Float64),
BuiltinScalarFunction::EndsWith => Ok(Boolean),

BuiltinScalarFunction::Factorial => Ok(Int64),

BuiltinScalarFunction::Nanvl => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},

BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => {
match input_expr_types[0] {
Float32 => Ok(Float32),
Expand Down Expand Up @@ -199,11 +185,6 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Nanvl => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
self.volatility(),
),
BuiltinScalarFunction::Factorial => {
Signature::uniform(1, vec![Int64], self.volatility())
}
Expand Down Expand Up @@ -240,8 +221,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Exp => &["exp"],
BuiltinScalarFunction::Factorial => &["factorial"],
BuiltinScalarFunction::Nanvl => &["nanvl"],
BuiltinScalarFunction::Random => &["random"],

// conditional functions
BuiltinScalarFunction::Coalesce => &["coalesce"],
Expand Down
11 changes: 2 additions & 9 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1903,8 +1903,8 @@ mod test {
use crate::expr::Cast;
use crate::expr_fn::col;
use crate::{
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
case, lit, ColumnarValue, Expr, ScalarFunctionDefinition, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Column;
Expand Down Expand Up @@ -2018,13 +2018,6 @@ mod test {

#[test]
fn test_is_volatile_scalar_func_definition() {
// BuiltIn
assert!(
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
.is_volatile()
.unwrap()
);

// UDF
#[derive(Debug)]
struct TestScalarUDF {
Expand Down
7 changes: 0 additions & 7 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,6 @@ pub fn concat_ws(sep: Expr, values: Vec<Expr>) -> Expr {
))
}

/// Returns a random value in the range 0.0 <= x < 1.0
pub fn random() -> Expr {
Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Random, vec![]))
}

/// Returns the approximate number of distinct input values.
/// This function provides an approximation of count(DISTINCT x).
/// Zero is returned if all input values are null.
Expand Down Expand Up @@ -550,7 +545,6 @@ nary_scalar_expr!(
"concatenates several strings, placing a seperator between each one"
);
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y");

/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
pub fn case(expr: Expr) -> CaseBuilder {
Expand Down Expand Up @@ -922,7 +916,6 @@ mod test {
test_unary_scalar_expr!(Factorial, factorial);
test_unary_scalar_expr!(Ceil, ceil);
test_unary_scalar_expr!(Exp, exp);
test_scalar_expr!(Nanvl, nanvl, x, y);

test_scalar_expr!(InitCap, initcap, string);
test_scalar_expr!(EndsWith, ends_with, string, characters);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub enum Volatility {
Stable,
/// A volatile function may change the return value from evaluation to evaluation.
/// Multiple invocations of a volatile function may return different results when used in the
/// same query. An example of this is [super::BuiltinScalarFunction::Random]. DataFusion
/// same query. An example of this is the random() function. DataFusion
/// can not evaluate such functions during planning.
/// In the query `select col1, random() from t1`, `random()` function will be evaluated
/// for each output row, resulting in a unique random value for each row.
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ hex = { version = "0.4", optional = true }
itertools = { workspace = true }
log = { workspace = true }
md-5 = { version = "^0.10.0", optional = true }
rand = { workspace = true }
regex = { version = "1.8", optional = true }
sha2 = { version = "^0.10.1", optional = true }
unicode-segmentation = { version = "^1.7.1", optional = true }
Expand Down
16 changes: 16 additions & 0 deletions datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ pub mod iszero;
pub mod lcm;
pub mod log;
pub mod nans;
pub mod nanvl;
pub mod pi;
pub mod power;
pub mod random;
pub mod round;
pub mod trunc;

Expand All @@ -55,9 +57,11 @@ make_udf_function!(lcm::LcmFunc, LCM, lcm);
make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)]));
make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)]));
make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)]));
make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl);
make_udf_function!(pi::PiFunc, PI, pi);
make_udf_function!(power::PowerFunc, POWER, power);
make_math_unary_udf!(RadiansFunc, RADIANS, radians, to_radians, None);
make_udf_function!(random::RandomFunc, RANDOM, random);
make_udf_function!(round::RoundFunc, ROUND, round);
make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, None);
make_math_unary_udf!(SinFunc, SIN, sin, sin, None);
Expand Down Expand Up @@ -180,6 +184,11 @@ pub mod expr_fn {
super::log10().call(vec![num])
}

#[doc = "returns x if x is not NaN otherwise returns y"]
pub fn nanvl(x: Expr, y: Expr) -> Expr {
super::nanvl().call(vec![x, y])
}

#[doc = "Returns an approximate value of π"]
pub fn pi() -> Expr {
super::pi().call(vec![])
Expand All @@ -195,6 +204,11 @@ pub mod expr_fn {
super::radians().call(vec![num])
}

#[doc = "Returns a random value in the range 0.0 <= x < 1.0"]
pub fn random() -> Expr {
super::random().call(vec![])
}

#[doc = "round to nearest integer"]
pub fn round(args: Vec<Expr>) -> Expr {
super::round().call(args)
Expand Down Expand Up @@ -261,9 +275,11 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
log(),
log2(),
log10(),
nanvl(),
pi(),
power(),
radians(),
random(),
round(),
signum(),
sin(),
Expand Down
Loading

0 comments on commit 9bc03ed

Please sign in to comment.