Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust, python): building blocks for expression expansion sets #9231

Merged
merged 4 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 8 additions & 242 deletions polars/polars-lazy/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
@@ -1,255 +1,14 @@
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::ops::Deref;

use polars_core::prelude::*;
use polars_core::utils::get_supertype;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde::{Deserializer, Serializer};

pub use super::expr_dyn_fn::*;
use crate::dsl::function_expr::FunctionExpr;
use crate::prelude::*;

/// A wrapper trait for any closure `Fn(&mut [Series]) -> PolarsResult<Option<Series>>`
pub trait SeriesUdf: Send + Sync {
/// Calls the wrapped function with the given mutable slice of series.
fn call_udf(&self, s: &mut [Series]) -> PolarsResult<Option<Series>>;
}

/// Blanket implementation: every thread-safe closure with the matching
/// signature is usable as a `SeriesUdf`.
impl<F: Fn(&mut [Series]) -> PolarsResult<Option<Series>> + Send + Sync> SeriesUdf for F {
    fn call_udf(&self, series: &mut [Series]) -> PolarsResult<Option<Series>> {
        // Delegate straight to the wrapped closure.
        (self)(series)
    }
}

/// Opaque `Debug` output: only the fixed trait name is printed.
impl Debug for dyn SeriesUdf {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("SeriesUdf")
    }
}

/// A wrapper trait for any binary closure `Fn(Series, Series) -> PolarsResult<Series>`
pub trait SeriesBinaryUdf: Send + Sync {
/// Calls the wrapped function with the two series, taking ownership of both.
fn call_udf(&self, a: Series, b: Series) -> PolarsResult<Series>;
}

/// Blanket implementation: every thread-safe binary closure with the matching
/// signature is usable as a `SeriesBinaryUdf`.
impl<F: Fn(Series, Series) -> PolarsResult<Series> + Send + Sync> SeriesBinaryUdf for F {
    fn call_udf(&self, lhs: Series, rhs: Series) -> PolarsResult<Series> {
        // Delegate straight to the wrapped closure.
        (self)(lhs, rhs)
    }
}

/// Opaque `Debug` output: only the fixed trait name is printed.
impl Debug for dyn SeriesBinaryUdf {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("SeriesBinaryUdf")
    }
}

/// There is no meaningful default binary UDF: calling `default()` always
/// panics with "implementation error".
// NOTE(review): presumably this impl exists only to satisfy a `Default` bound
// elsewhere (e.g. on a skipped serde field) — confirm against the users of this type.
impl Default for SpecialEq<Arc<dyn SeriesBinaryUdf>> {
fn default() -> Self {
panic!("implementation error");
}
}

/// The default output-field resolver ignores all of its inputs and always
/// returns `None`.
impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {
    fn default() -> Self {
        Self::new(Arc::new(
            |_: &Schema, _: Context, _: &Field, _: &Field| -> Option<Field> { None },
        ))
    }
}

/// A wrapper trait for any closure `Fn(&str) -> PolarsResult<String>`
pub trait RenameAliasFn: Send + Sync {
/// Calls the wrapped function with `name` and returns the resulting `String` or an error.
fn call(&self, name: &str) -> PolarsResult<String>;
}

/// Blanket implementation: every thread-safe closure with the matching
/// signature is usable as a `RenameAliasFn`.
impl<F> RenameAliasFn for F
where
    F: Fn(&str) -> PolarsResult<String> + Send + Sync,
{
    fn call(&self, original: &str) -> PolarsResult<String> {
        // Delegate straight to the wrapped closure.
        (self)(original)
    }
}

/// Opaque `Debug` output: only the fixed trait name is printed.
impl Debug for dyn RenameAliasFn {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("RenameAliasFn")
    }
}

#[derive(Clone)]
/// Wrapper type that has special equality properties
/// depending on the inner type specialization:
/// `SpecialEq<Arc<T>>` compares by pointer identity (`Arc::ptr_eq`), while
/// `SpecialEq<Series>` compares the wrapped series with `==`.
///
/// The `Debug` output is always the fixed string `no_eq`, and the wrapper
/// dereferences to the inner value.
pub struct SpecialEq<T>(T);

/// Serializes transparently: the wrapper emits exactly what the inner value emits.
#[cfg(feature = "serde")]
impl<T: Serialize> Serialize for SpecialEq<T> {
    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
        let Self(inner) = self;
        T::serialize(inner, serializer)
    }
}

/// Deserializes transparently: the inner value is read and then wrapped.
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq<T> {
    fn deserialize<D: Deserializer<'a>>(deserializer: D) -> std::result::Result<Self, D::Error> {
        T::deserialize(deserializer).map(SpecialEq)
    }
}

impl<T> SpecialEq<T> {
pub fn new(val: T) -> Self {
SpecialEq(val)
}
}

impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}

impl PartialEq for SpecialEq<Series> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

/// The inner value is never printed; the output is always the fixed string `no_eq`.
impl<T> Debug for SpecialEq<T> {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("no_eq")
    }
}

impl<T> Deref for SpecialEq<T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.0
}
}

/// A wrapper trait for any closure
/// `Fn(&Schema, Context, &Field, &Field) -> Option<Field>`.
pub trait BinaryUdfOutputField: Send + Sync {
/// Calls the wrapped function with the input schema, the context and the two
/// input fields; returns an optional output field.
fn get_field(
&self,
input_schema: &Schema,
cntxt: Context,
field_a: &Field,
field_b: &Field,
) -> Option<Field>;
}

/// Blanket implementation: every thread-safe closure with the matching
/// signature is usable as a `BinaryUdfOutputField`.
impl<F> BinaryUdfOutputField for F
where
    F: Fn(&Schema, Context, &Field, &Field) -> Option<Field> + Send + Sync,
{
    fn get_field(
        &self,
        schema: &Schema,
        context: Context,
        lhs: &Field,
        rhs: &Field,
    ) -> Option<Field> {
        // Forward every argument, in order, to the wrapped closure.
        (self)(schema, context, lhs, rhs)
    }
}

/// A wrapper trait for any closure `Fn(&Schema, Context, &[Field]) -> Field`.
pub trait FunctionOutputField: Send + Sync {
/// Calls the wrapped function with the input schema, the context and the
/// input fields; returns the output field.
fn get_field(&self, input_schema: &Schema, cntxt: Context, fields: &[Field]) -> Field;
}

/// A shared, dynamically dispatched [`FunctionOutputField`] wrapped in [`SpecialEq`].
pub type GetOutput = SpecialEq<Arc<dyn FunctionOutputField>>;

impl Default for GetOutput {
fn default() -> Self {
SpecialEq::new(Arc::new(
|_input_schema: &Schema, _cntxt: Context, fields: &[Field]| fields[0].clone(),
))
}
}

impl GetOutput {
pub fn same_type() -> Self {
Default::default()
}

pub fn from_type(dt: DataType) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
Field::new(flds[0].name(), dt.clone())
}))
}

pub fn map_field<F: 'static + Fn(&Field) -> Field + Send + Sync>(f: F) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
f(&flds[0])
}))
}

pub fn map_fields<F: 'static + Fn(&[Field]) -> Field + Send + Sync>(f: F) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
f(flds)
}))
}

pub fn map_dtype<F: 'static + Fn(&DataType) -> DataType + Send + Sync>(f: F) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
let mut fld = flds[0].clone();
let new_type = f(fld.data_type());
fld.coerce(new_type);
fld
}))
}

pub fn float_type() -> Self {
Self::map_dtype(|dt| match dt {
DataType::Float32 => DataType::Float32,
_ => DataType::Float64,
})
}

pub fn super_type() -> Self {
Self::map_dtypes(|dtypes| {
let mut st = dtypes[0].clone();
for dt in &dtypes[1..] {
st = get_supertype(&st, dt).unwrap();
}
st
})
}

pub fn map_dtypes<F>(f: F) -> Self
where
F: 'static + Fn(&[&DataType]) -> DataType + Send + Sync,
{
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
let mut fld = flds[0].clone();
let dtypes = flds.iter().map(|fld| fld.data_type()).collect::<Vec<_>>();
let new_type = f(&dtypes);
fld.coerce(new_type);
fld
}))
}
}

/// Blanket implementation: every thread-safe closure with the matching
/// signature is usable as a `FunctionOutputField`.
impl<F: Fn(&Schema, Context, &[Field]) -> Field + Send + Sync> FunctionOutputField for F {
    fn get_field(&self, schema: &Schema, context: Context, inputs: &[Field]) -> Field {
        // Forward every argument, in order, to the wrapped closure.
        (self)(schema, context, inputs)
    }
}

#[derive(PartialEq, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AggExpr {
Expand Down Expand Up @@ -399,6 +158,13 @@ pub enum Expr {
input: Box<Expr>,
id: usize,
},
/// Expressions in this node should only be expanding expressions,
/// e.g.
/// `Expr::Columns`
/// `Expr::Dtypes`
/// `Expr::Wildcard`
/// `Expr::Exclude`
Selector(super::selector::Selector),
}

// TODO! derive. This is only a temporary fix
Expand Down
Loading