Skip to content

Commit

Permalink
Clean internal implementation of WindowUDF (#8746)
Browse files Browse the repository at this point in the history
* Clean internal implementation of WindowUDF

* fix doc
  • Loading branch information
guojidan authored Jan 5, 2024
1 parent 29f23eb commit 98f02ff
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 50 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use datafusion_expr::{
/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance.
///
/// To do so, we must implement the `WindowUDFImpl` trait.
#[derive(Debug, Clone)]
struct SmoothItUdf {
signature: Signature,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ impl OddCounter {
}

fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
#[derive(Debug, Clone)]
struct SimpleWindowUDF {
signature: Signature,
return_type: DataType,
Expand Down
12 changes: 12 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl};
use arrow::datatypes::DataType;
use datafusion_common::{Column, Result};
use std::any::Any;
use std::fmt::Debug;
use std::ops::Not;
use std::sync::Arc;

Expand Down Expand Up @@ -1078,6 +1079,17 @@ pub struct SimpleWindowUDF {
partition_evaluator_factory: PartitionEvaluatorFactory,
}

/// Hand-written `Debug` for [`SimpleWindowUDF`].
///
/// The `name` and `signature` fields are formatted with their own `Debug`
/// impls. The `return_type` and `partition_evaluator_factory` entries are
/// rendered as the fixed placeholder `"<func>"` instead of being formatted.
impl Debug for SimpleWindowUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Label the output with this type's own name so that debug dumps are
        // not mistaken for a `WindowUDF`.
        f.debug_struct("SimpleWindowUDF")
            .field("name", &self.name)
            .field("signature", &self.signature)
            .field("return_type", &"<func>")
            .field("partition_evaluator_factory", &"<func>")
            .finish()
    }
}

impl SimpleWindowUDF {
/// Create a new `SimpleWindowUDF` from a name, input types, return type and
/// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
Expand Down
138 changes: 88 additions & 50 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,49 +34,42 @@ use std::{
///
/// See the documentation on [`PartitionEvaluator`] for more details
///
/// 1. For simple (less performant) use cases, use [`create_udwf`] and [`simple_udwf.rs`].
///
/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udwf.rs`].
///
/// # API Note
/// This is a separate struct from `WindowUDFImpl` to maintain backwards
/// compatibility with the older API.
///
/// [`PartitionEvaluator`]: crate::PartitionEvaluator
#[derive(Clone)]
/// [`create_udwf`]: crate::expr_fn::create_udwf
/// [`simple_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs
/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs
#[derive(Debug, Clone)]
pub struct WindowUDF {
/// name
name: String,
/// signature
signature: Signature,
/// Return type
return_type: ReturnTypeFunction,
/// Return the partition evaluator
partition_evaluator_factory: PartitionEvaluatorFactory,
}

impl Debug for WindowUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("WindowUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("return_type", &"<func>")
.field("partition_evaluator_factory", &"<func>")
.finish_non_exhaustive()
}
inner: Arc<dyn WindowUDFImpl>,
}

/// Defines how the WindowUDF is shown to users
impl Display for WindowUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.name)
write!(f, "{}", self.name())
}
}

impl PartialEq for WindowUDF {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.signature == other.signature
self.name() == other.name() && self.signature() == other.signature()
}
}

impl Eq for WindowUDF {}

impl std::hash::Hash for WindowUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
self.name().hash(state);
self.signature().hash(state);
}
}

Expand All @@ -92,40 +85,31 @@ impl WindowUDF {
return_type: &ReturnTypeFunction,
partition_evaluator_factory: &PartitionEvaluatorFactory,
) -> Self {
Self {
name: name.to_string(),
Self::new_from_impl(WindowUDFLegacyWrapper {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
partition_evaluator_factory: partition_evaluator_factory.clone(),
}
})
}

/// Create a new `WindowUDF` from a [`WindowUDFImpl`] trait object
///
/// Note this is the same as using the `From` impl (`WindowUDF::from`)
pub fn new_from_impl<F>(fun: F) -> WindowUDF
where
F: WindowUDFImpl + Send + Sync + 'static,
F: WindowUDFImpl + 'static,
{
let arc_fun = Arc::new(fun);
let captured_self = arc_fun.clone();
let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
let return_type = captured_self.return_type(arg_types)?;
Ok(Arc::new(return_type))
});

let captured_self = arc_fun.clone();
let partition_evaluator_factory: PartitionEvaluatorFactory =
Arc::new(move || captured_self.partition_evaluator());

Self {
name: arc_fun.name().to_string(),
signature: arc_fun.signature().clone(),
return_type: return_type.clone(),
partition_evaluator_factory,
inner: Arc::new(fun),
}
}

/// Return the underlying [`WindowUDFImpl`] trait object for this function
///
/// The returned `Arc` is a new handle to the same implementation; the
/// implementation itself is not copied.
pub fn inner(&self) -> Arc<dyn WindowUDFImpl> {
    Arc::clone(&self.inner)
}

/// Creates an [`Expr`] that calls the window function given
/// the `partition_by`, `order_by`, and `window_frame` definition
///
Expand All @@ -150,25 +134,29 @@ impl WindowUDF {
}

/// Returns this function's name
///
/// See [`WindowUDFImpl::name`] for more details.
pub fn name(&self) -> &str {
&self.name
self.inner.name()
}

/// Returns this function's signature (what input types are accepted)
///
/// See [`WindowUDFImpl::signature`] for more details.
pub fn signature(&self) -> &Signature {
&self.signature
self.inner.signature()
}

/// Return the type of the function given its input types
///
/// See [`WindowUDFImpl::return_type`] for more details.
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
self.inner.return_type(args)
}

/// Return a `PartitionEvaluator` for evaluating this window function
pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> {
(self.partition_evaluator_factory)()
self.inner.partition_evaluator()
}
}

Expand Down Expand Up @@ -198,6 +186,7 @@ where
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame};
/// # use datafusion_expr::{WindowUDFImpl, WindowUDF};
/// #[derive(Debug, Clone)]
/// struct SmoothIt {
/// signature: Signature
/// };
Expand Down Expand Up @@ -236,7 +225,7 @@ where
/// WindowFrame::new(false),
/// );
/// ```
pub trait WindowUDFImpl {
pub trait WindowUDFImpl: Debug + Send + Sync {
/// Returns this object as an [`Any`] trait object
fn as_any(&self) -> &dyn Any;

Expand All @@ -254,3 +243,52 @@ pub trait WindowUDFImpl {
/// Invoke the function, returning the [`PartitionEvaluator`] instance
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>;
}

/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers
/// of the older API (see <https://github.com/apache/arrow-datafusion/pull/8719>
/// for more details)
///
/// Each field backs one method of the [`WindowUDFImpl`] implementation for
/// this type: `name` and `signature` are returned by reference, while
/// `return_type` and `partition_evaluator_factory` are invoked on demand.
pub struct WindowUDFLegacyWrapper {
/// Name of the function, as returned by [`WindowUDFImpl::name`]
name: String,
/// Signature of the function, as returned by [`WindowUDFImpl::signature`]
signature: Signature,
/// Old-style function that computes the return type from the argument
/// types; it is called by [`WindowUDFImpl::return_type`]
return_type: ReturnTypeFunction,
/// Old-style factory that produces the partition evaluator; it is called by
/// [`WindowUDFImpl::partition_evaluator`]
partition_evaluator_factory: PartitionEvaluatorFactory,
}

/// Hand-written `Debug` for [`WindowUDFLegacyWrapper`].
///
/// `name` and `signature` are formatted normally. The two function-style
/// fields are shown as the placeholder `"<func>"`, and the output is marked
/// non-exhaustive (it ends with `..`).
impl Debug for WindowUDFLegacyWrapper {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("WindowUDF");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        // The function-style fields are represented by a fixed marker.
        builder.field("return_type", &"<func>");
        builder.field("partition_evaluator_factory", &"<func>");
        builder.finish_non_exhaustive()
    }
}

/// Adapts the function-pointer fields of [`WindowUDFLegacyWrapper`] to the
/// [`WindowUDFImpl`] trait by delegating each method to the matching field.
impl WindowUDFImpl for WindowUDFLegacyWrapper {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // The legacy function hands back the type behind an `Arc`; clone the
        // value out so callers receive an owned `DataType`.
        (self.return_type)(arg_types).map(|shared| shared.as_ref().clone())
    }

    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
        let make_evaluator = &*self.partition_evaluator_factory;
        make_evaluator()
    }
}
1 change: 1 addition & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,7 @@ fn roundtrip_window() {
}
}

#[derive(Debug, Clone)]
struct SimpleWindowUDF {
signature: Signature,
}
Expand Down

0 comments on commit 98f02ff

Please sign in to comment.