Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make fields of ScalarUDF, AggregateUDF and WindowUDF non-pub #8079

Merged
merged 3 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
}
}
Expr::ScalarUDF(ScalarUDF { fun, .. }) => {
match fun.signature.volatility {
match fun.signature().volatility {
Volatility::Immutable => VisitRecursion::Continue,
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ impl SessionContext {
self.state
.write()
.scalar_functions
.insert(f.name.clone(), Arc::new(f));
.insert(f.name().to_string(), Arc::new(f));
}

/// Registers an aggregate UDF within this context.
Expand All @@ -820,7 +820,7 @@ impl SessionContext {
self.state
.write()
.aggregate_functions
.insert(f.name.clone(), Arc::new(f));
.insert(f.name().to_string(), Arc::new(f));
}

/// Registers a window UDF within this context.
Expand All @@ -834,7 +834,7 @@ impl SessionContext {
self.state
.write()
.window_functions
.insert(f.name.clone(), Arc::new(f));
.insert(f.name().to_string(), Arc::new(f));
}

/// Creates a [`DataFrame`] for reading a data source.
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
create_function_physical_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_physical_name(&fun.name, false, args)
create_function_physical_name(fun.name(), false, args)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
Expand Down Expand Up @@ -250,7 +250,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
for e in args {
names.push(create_physical_name(e, false)?);
}
Ok(format!("{}({})", fun.name, names.join(",")))
Ok(format!("{}({})", fun.name(), names.join(",")))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
Expand Down
14 changes: 8 additions & 6 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ impl Between {
}
}

/// ScalarFunction expression
/// ScalarFunction expression invokes a built-in scalar function
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarFunction {
/// The function
Expand All @@ -351,7 +351,9 @@ impl ScalarFunction {
}
}

/// ScalarUDF expression
/// ScalarUDF expression invokes a user-defined scalar function [`ScalarUDF`]
///
/// [`ScalarUDF`]: crate::ScalarUDF
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarUDF {
/// The function
Expand Down Expand Up @@ -1178,7 +1180,7 @@ impl fmt::Display for Expr {
fmt_function(f, &func.fun.to_string(), false, &func.args, true)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
fmt_function(f, &fun.name, false, args, true)
fmt_function(f, fun.name(), false, args, true)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down Expand Up @@ -1225,7 +1227,7 @@ impl fmt::Display for Expr {
order_by,
..
}) => {
fmt_function(f, &fun.name, false, args, true)?;
fmt_function(f, fun.name(), false, args, true)?;
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {fe})")?;
}
Expand Down Expand Up @@ -1512,7 +1514,7 @@ fn create_name(e: &Expr) -> Result<String> {
create_function_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_name(&fun.name, false, args)
create_function_name(fun.name(), false, args)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down Expand Up @@ -1565,7 +1567,7 @@ fn create_name(e: &Expr) -> Result<String> {
if let Some(ob) = order_by {
info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob));
}
Ok(format!("{}({}){}", fun.name, names.join(","), info))
Ok(format!("{}({}){}", fun.name(), names.join(","), info))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
Ok(fun.return_type(&data_types)?)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This piece of code looks really clean.

}
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let data_types = args
Expand Down Expand Up @@ -115,7 +115,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
fun.return_type(&data_types)
}
Expr::Not(_)
| Expr::IsNull(_)
Expand Down
47 changes: 40 additions & 7 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
// specific language governing permissions and limitations
// under the License.

//! Udaf module contains functions and structs supporting user-defined aggregate functions.
//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::Expr;
use crate::{Accumulator, Expr};
use crate::{
AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

Expand All @@ -46,15 +48,15 @@ use std::sync::Arc;
#[derive(Clone)]
pub struct AggregateUDF {
/// name
pub name: String,
name: String,
/// Signature (input arguments)
pub signature: Signature,
signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
return_type: ReturnTypeFunction,
/// actual implementation
pub accumulator: AccumulatorFactoryFunction,
accumulator: AccumulatorFactoryFunction,
/// the accumulator's state's description as a function of the return type
pub state_type: StateTypeFunction,
state_type: StateTypeFunction,
}

impl Debug for AggregateUDF {
Expand Down Expand Up @@ -112,4 +114,35 @@ impl AggregateUDF {
order_by: None,
})
}

/// The name of this aggregate function, borrowed from the struct.
pub fn name(&self) -> &str {
    self.name.as_str()
}

/// Returns this function's signature (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature
}

/// Return the type of the function given its input types
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another change I did was to move the handling of function pointers into the implementation of AggregateUDF etc -- so rather than passing out parts of themselves (a function pointer) the rest of DataFusion now just calls the function to get the appropriate information

This is in preparation for potentially changing the implementation of *UDF internally

let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
}

/// Creates an accumulator for this aggregate by calling the stored
/// accumulator factory with the aggregate's return datatype.
///
/// Any error produced by the factory is returned unchanged.
pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn Accumulator>> {
    let build = &self.accumulator;
    build(return_type)
}

/// Computes the datatypes of the intermediate state kept by this aggregator
/// for the given return datatype. Supports multi-phase aggregations.
///
/// Errors from the underlying state-type function are propagated.
pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
    let shared = (self.state_type)(return_type)?;
    // The old API hands back an `Arc`: take ownership when we hold the only
    // reference, otherwise fall back to copying the contents.
    let owned = match Arc::try_unwrap(shared) {
        Ok(types) => types,
        Err(still_shared) => (*still_shared).clone(),
    };
    Ok(owned)
}
}
50 changes: 41 additions & 9 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,31 @@
// specific language governing permissions and limitations
// under the License.

//! Udf module contains foundational types that are used to represent UDFs in DataFusion.
//! [`ScalarUDF`]: Scalar User Defined Functions

use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::sync::Arc;

/// Logical representation of a UDF.
/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input.
///
/// This struct contains the information DataFusion needs to plan and invoke
/// functions such as name, type signature, return type, and actual implementation.
///
#[derive(Clone)]
pub struct ScalarUDF {
/// name
pub name: String,
/// signature
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
/// The name of the function
name: String,
/// The signature (the types of arguments that are supported)
signature: Signature,
/// Function that returns the return type given the argument types
return_type: ReturnTypeFunction,
/// actual implementation
///
/// The fn param is the wrapped function but be aware that the function will
Expand All @@ -40,7 +48,7 @@ pub struct ScalarUDF {
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
pub fun: ScalarFunctionImplementation,
fun: ScalarFunctionImplementation,
}

impl Debug for ScalarUDF {
Expand Down Expand Up @@ -89,4 +97,28 @@ impl ScalarUDF {
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args))
}

/// The name of this scalar function, borrowed from the struct.
pub fn name(&self) -> &str {
    self.name.as_str()
}

/// Returns this function's signature (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature
}

/// Computes this function's return type for the given argument types.
///
/// Calls the stored return-type function with `args`; an error from that
/// function is returned as-is.
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
    // The old API wraps the datatype in an `Arc`, so copy the value out of it.
    (self.return_type)(args).map(|shared| (*shared).clone())
}

/// Returns a clone of the stored implementation of this function.
pub fn fun(&self) -> ScalarFunctionImplementation {
    Clone::clone(&self.fun)
}

// TODO maybe add an invoke() method that runs the actual function?
}
44 changes: 34 additions & 10 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
// specific language governing permissions and limitations
// under the License.

//! Support for user-defined window (UDWF) window functions
//! [`WindowUDF`]: User Defined Window Functions

use crate::{
Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature,
WindowFrame,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::{
fmt::{self, Debug, Display, Formatter},
sync::Arc,
};

use crate::{
Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
};

/// Logical representation of a user-defined window function (UDWF)
/// A UDWF is different from a UDF in that it is stateful across batches.
///
Expand All @@ -35,13 +37,13 @@ use crate::{
#[derive(Clone)]
pub struct WindowUDF {
/// name
pub name: String,
name: String,
/// signature
pub signature: Signature,
signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
return_type: ReturnTypeFunction,
/// Return the partition evaluator
pub partition_evaluator_factory: PartitionEvaluatorFactory,
partition_evaluator_factory: PartitionEvaluatorFactory,
}

impl Debug for WindowUDF {
Expand Down Expand Up @@ -86,7 +88,7 @@ impl WindowUDF {
partition_evaluator_factory: &PartitionEvaluatorFactory,
) -> Self {
Self {
name: name.to_owned(),
name: name.to_string(),
signature: signature.clone(),
return_type: return_type.clone(),
partition_evaluator_factory: partition_evaluator_factory.clone(),
Expand Down Expand Up @@ -115,4 +117,26 @@ impl WindowUDF {
window_frame,
})
}

/// The name of this window function, borrowed from the struct.
pub fn name(&self) -> &str {
    self.name.as_str()
}

/// Returns this function's signature (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature
}

/// Computes this function's return type for the given argument types.
///
/// Calls the stored return-type function with `args`; an error from that
/// function is returned as-is.
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
    // The old API wraps the datatype in an `Arc`, so copy the value out of it.
    (self.return_type)(args).map(|shared| (*shared).clone())
}

/// Builds a `PartitionEvaluator` for evaluating this window function by
/// calling the stored factory; the factory's result is returned unchanged.
pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> {
    let make_evaluator = &self.partition_evaluator_factory;
    make_evaluator()
}
}
12 changes: 4 additions & 8 deletions datafusion/expr/src/window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,8 @@ impl WindowFunction {
WindowFunction::BuiltInWindowFunction(fun) => {
fun.return_type(input_expr_types)
}
WindowFunction::AggregateUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
WindowFunction::WindowUDF(fun) => {
Ok((*(fun.return_type)(input_expr_types)?).clone())
}
WindowFunction::AggregateUDF(fun) => fun.return_type(input_expr_types),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this now reads more nicely and idiomatically

WindowFunction::WindowUDF(fun) => fun.return_type(input_expr_types),
}
}
}
Expand Down Expand Up @@ -234,8 +230,8 @@ impl WindowFunction {
match self {
WindowFunction::AggregateFunction(fun) => fun.signature(),
WindowFunction::BuiltInWindowFunction(fun) => fun.signature(),
WindowFunction::AggregateUDF(fun) => fun.signature.clone(),
WindowFunction::WindowUDF(fun) => fun.signature.clone(),
WindowFunction::AggregateUDF(fun) => fun.signature().clone(),
WindowFunction::WindowUDF(fun) => fun.signature().clone(),
}
}
}
Expand Down
Loading