Skip to content

Commit

Permalink
fix: Various schema corrections (#18474)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Aug 30, 2024
1 parent 37e9ccd commit a5dc30d
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 173 deletions.
7 changes: 6 additions & 1 deletion crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ impl AExpr {
*nested = 0;
Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE))
},
Window { function, .. } => {
Window {
function, options, ..
} => {
if let WindowType::Over(mapping) = options {
*nested += matches!(mapping, WindowMapping::Join) as u8;
}
let e = arena.get(*function);
e.to_field_impl(schema, arena, nested)
},
Expand Down
74 changes: 15 additions & 59 deletions crates/polars-plan/src/plans/conversion/expr_to_ir.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::plans::conversion::functions::convert_functions;

pub fn to_expr_ir(expr: Expr, arena: &mut Arena<AExpr>) -> PolarsResult<ExprIR> {
let mut state = ConversionContext::new();
Expand Down Expand Up @@ -40,12 +41,12 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> PolarsResult<Node> {
}

#[derive(Default)]
struct ConversionContext {
output_name: OutputName,
pub(super) struct ConversionContext {
pub(super) output_name: OutputName,
/// Remove alias from the expressions and set as [`OutputName`].
prune_alias: bool,
pub(super) prune_alias: bool,
/// If an `alias` is encountered prune and ignore it.
ignore_alias: bool,
pub(super) ignore_alias: bool,
}

impl ConversionContext {
Expand All @@ -68,14 +69,17 @@ fn to_aexprs(
.collect()
}

fn set_function_output_name<F>(e: &[ExprIR], state: &mut ConversionContext, function_fmt: F)
where
F: FnOnce() -> Cow<'static, str>,
pub(super) fn set_function_output_name<F>(
e: &[ExprIR],
state: &mut ConversionContext,
function_fmt: F,
) where
F: FnOnce() -> PlSmallStr,
{
if state.output_name.is_none() {
if e.is_empty() {
let s = function_fmt();
state.output_name = OutputName::LiteralLhs(PlSmallStr::from_str(s.as_ref()));
state.output_name = OutputName::LiteralLhs(s);
} else {
state.output_name = e[0].output_name_inner().clone();
}
Expand Down Expand Up @@ -117,7 +121,7 @@ fn to_aexpr_impl_materialized_lit(

/// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation.
#[recursive]
fn to_aexpr_impl(
pub(super) fn to_aexpr_impl(
expr: Expr,
arena: &mut Arena<AExpr>,
state: &mut ConversionContext,
Expand Down Expand Up @@ -281,7 +285,7 @@ fn to_aexpr_impl(
options,
} => {
let e = to_expr_irs(input, arena)?;
set_function_output_name(&e, state, || Cow::Borrowed(options.fmt_str));
set_function_output_name(&e, state, || PlSmallStr::from_static(options.fmt_str));
AExpr::AnonymousFunction {
input: e,
function,
Expand All @@ -293,55 +297,7 @@ fn to_aexpr_impl(
input,
function,
options,
} => {
match function {
// This can be created by col(*).is_null() on empty dataframes.
FunctionExpr::Boolean(
BooleanFunction::AllHorizontal | BooleanFunction::AnyHorizontal,
) if input.is_empty() => {
return to_aexpr_impl(lit(true), arena, state);
},
// Convert to binary expression as the optimizer understands those.
// Don't exceed 128 expressions as we might stackoverflow.
FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => {
if input.len() < 128 {
let expr = input
.into_iter()
.reduce(|l, r| l.logical_and(r))
.unwrap()
.cast(DataType::Boolean);
return to_aexpr_impl(expr, arena, state);
}
},
FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => {
if input.len() < 128 {
let expr = input
.into_iter()
.reduce(|l, r| l.logical_or(r))
.unwrap()
.cast(DataType::Boolean);
return to_aexpr_impl(expr, arena, state);
}
},
_ => {},
}

let e = to_expr_irs(input, arena)?;

if state.output_name.is_none() {
// Handles special case functions like `struct.field`.
if let Some(name) = function.output_name() {
state.output_name = name
} else {
set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function)));
}
}
AExpr::Function {
input: e,
function,
options,
}
},
} => return convert_functions(input, function, options, arena, state),
Expr::Window {
function,
partition_by,
Expand Down
66 changes: 66 additions & 0 deletions crates/polars-plan/src/plans/conversion/functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use arrow::legacy::error::PolarsResult;
use polars_utils::arena::{Arena, Node};
use polars_utils::format_pl_smallstr;

use super::*;
use crate::dsl::{Expr, FunctionExpr};
use crate::plans::AExpr;
use crate::prelude::FunctionOptions;

/// Lowers a DSL function call (`input`, `function`, `options`) into the arena and
/// returns the node of the resulting expression.
///
/// `all_horizontal` / `any_horizontal` are special-cased:
/// * with no inputs they are replaced by `lit(true)`;
/// * with fewer than 128 inputs they are folded into a chain of binary
///   `logical_and` / `logical_or` expressions cast to `Boolean`, because the
///   optimizer understands those. Larger inputs keep the function form so the
///   nested chain does not risk a stack overflow.
///
/// Every other call becomes an [`AExpr::Function`]. When `state` carries no output
/// name yet, it is taken from `function.output_name()` if that yields one
/// (e.g. `struct.field`), and otherwise set through `set_function_output_name`.
pub(super) fn convert_functions(
    input: Vec<Expr>,
    function: FunctionExpr,
    options: FunctionOptions,
    arena: &mut Arena<AExpr>,
    state: &mut ConversionContext,
) -> PolarsResult<Node> {
    let is_all = matches!(
        function,
        FunctionExpr::Boolean(BooleanFunction::AllHorizontal)
    );
    let is_any = matches!(
        function,
        FunctionExpr::Boolean(BooleanFunction::AnyHorizontal)
    );

    if is_all || is_any {
        // This can be created by col(*).is_null() on empty dataframes.
        if input.is_empty() {
            return to_aexpr_impl(lit(true), arena, state);
        }

        // Convert to binary expressions as the optimizer understands those.
        // Don't exceed 128 expressions as we might stackoverflow.
        if input.len() < 128 {
            let folded = input
                .into_iter()
                .reduce(|acc, rhs| {
                    if is_all {
                        acc.logical_and(rhs)
                    } else {
                        acc.logical_or(rhs)
                    }
                })
                // `input` was checked to be non-empty above, so `reduce` yields a value.
                .unwrap()
                .cast(DataType::Boolean);
            return to_aexpr_impl(folded, arena, state);
        }
    }

    let converted = to_expr_irs(input, arena)?;

    if state.output_name.is_none() {
        match function.output_name() {
            // Handles special case functions like `struct.field`.
            Some(name) => state.output_name = name,
            None => {
                set_function_output_name(&converted, state, || {
                    format_pl_smallstr!("{}", &function)
                });
            },
        }
    }

    Ok(arena.add(AExpr::Function {
        input: converted,
        function,
        options,
    }))
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ mod ir_to_dsl;
mod scans;
mod stack_opt;

use std::borrow::Cow;
use std::sync::{Arc, Mutex, RwLock};

pub use dsl_to_ir::*;
Expand All @@ -21,6 +20,7 @@ pub use ir_to_dsl::*;
use polars_core::prelude::*;
use polars_utils::vec::ConvertVec;
use recursive::recursive;
mod functions;
pub(crate) mod type_coercion;

pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,13 @@ pub(super) fn process_binary(
st = String
}

// only cast if the type is not already the super type.
// TODO! raise here?
// We should at least never cast to Unknown.
if matches!(st, DataType::Unknown(UnknownKind::Any)) {
return Ok(None);
}

// Only cast if the type is not already the super type.
// this can prevent an expensive flattening and subsequent aggregation
// in a group_by context. To be able to cast the groups need to be
// flattened
Expand Down
83 changes: 83 additions & 0 deletions crates/polars-plan/src/plans/conversion/type_coercion/functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use either::Either;

use super::*;

/// Collects the dtypes of all `input` expressions of `function`.
///
/// Returns `Either::Left(dtypes)` when every input dtype could be resolved, none
/// of them is `Unknown(Any)`, and they are not all equal.
///
/// Returns `Either::Right(node)` — the same function call rebuilt with
/// `options.cast_to_supertypes` cleared — when an input dtype cannot be resolved,
/// when an input is `Unknown(Any)`, or when all collected dtypes are equal.
///
/// # Errors
/// Propagates the error from `check_namespace` when the dtype of the first input
/// does not fit the namespace of `function`.
pub(super) fn get_function_dtypes(
    input: &[ExprIR],
    expr_arena: &Arena<AExpr>,
    input_schema: &Schema,
    function: &FunctionExpr,
    mut options: FunctionOptions,
) -> PolarsResult<Either<Vec<DataType>, AExpr>> {
    // Rebuilds the call unchanged apart from clearing `cast_to_supertypes`;
    // per the original note, this path is then not hit again on the next iteration.
    let mut unchanged = move || {
        options.cast_to_supertypes = None;
        let node = AExpr::Function {
            function: function.clone(),
            input: input.to_vec(),
            options,
        };
        Ok(Either::Right(node))
    };

    let mut dtypes = Vec::with_capacity(input.len());
    for (idx, expr) in input.iter().enumerate() {
        let Some((_, dtype)) = get_aexpr_and_type(expr_arena, expr.node(), input_schema) else {
            return unchanged();
        };

        // Only the first input is validated against the function's namespace.
        if idx == 0 {
            check_namespace(function, &dtype)?;
        }

        // An input of kind `Unknown(Any)` stops the collection: hand back the unchanged call.
        if matches!(dtype, DataType::Unknown(UnknownKind::Any)) {
            return unchanged();
        }
        dtypes.push(dtype);
    }

    if dtypes.iter().all_equal() {
        unchanged()
    } else {
        Ok(Either::Left(dtypes))
    }
}

/// Verifies that the dtype of a function's first input fits the function's namespace.
///
/// `str` namespace belongs to `String`,
/// `cat` namespace belongs to `Categorical` etc.
///
/// Functions outside the namespaces listed below are accepted without a check.
///
/// # Errors
/// Returns an `InvalidOperation` error naming the expected type and the received
/// `first_dtype` when they do not match.
fn check_namespace(function: &FunctionExpr, first_dtype: &DataType) -> PolarsResult<()> {
    match function {
        #[cfg(feature = "strings")]
        FunctionExpr::StringExpr(_) => {
            polars_ensure!(
                first_dtype == &DataType::String,
                InvalidOperation: "expected String type, got: {}", first_dtype
            )
        },
        FunctionExpr::BinaryExpr(_) => {
            polars_ensure!(
                first_dtype == &DataType::Binary,
                InvalidOperation: "expected Binary type, got: {}", first_dtype
            )
        },
        #[cfg(feature = "temporal")]
        FunctionExpr::TemporalExpr(_) => {
            polars_ensure!(
                first_dtype.is_temporal(),
                InvalidOperation: "expected Date(time)/Duration type, got: {}", first_dtype
            )
        },
        FunctionExpr::ListExpr(_) => {
            polars_ensure!(
                matches!(first_dtype, DataType::List(_)),
                InvalidOperation: "expected List type, got: {}", first_dtype
            )
        },
        #[cfg(feature = "dtype-array")]
        FunctionExpr::ArrayExpr(_) => {
            polars_ensure!(
                matches!(first_dtype, DataType::Array(_, _)),
                InvalidOperation: "expected Array type, got: {}", first_dtype
            )
        },
        #[cfg(feature = "dtype-struct")]
        FunctionExpr::StructExpr(_) => {
            polars_ensure!(
                matches!(first_dtype, DataType::Struct(_)),
                InvalidOperation: "expected Struct type, got: {}", first_dtype
            )
        },
        #[cfg(feature = "dtype-categorical")]
        FunctionExpr::Categorical(_) => {
            // The message previously said "Struct" (copy-paste from the arm above).
            polars_ensure!(
                matches!(first_dtype, DataType::Categorical(_, _)),
                InvalidOperation: "expected Categorical type, got: {}", first_dtype
            )
        },
        _ => {},
    }

    Ok(())
}
Loading

0 comments on commit a5dc30d

Please sign in to comment.