diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/fused.rs b/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/fused.rs index a4b63a9272e4..7ece608af567 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/fused.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/fused.rs @@ -3,14 +3,19 @@ use super::*; pub struct FusedArithmetic {} fn get_expr(input: Vec, op: FusedOperator) -> AExpr { + let mut options = FunctionOptions { + collect_groups: ApplyOptions::ApplyFlat, + cast_to_supertypes: true, + ..Default::default() + }; + // order of operations change because of FMA + // so we must toggle this check off + // it is still safe as it is a trusted operation + unsafe { options.no_check_lengths() } AExpr::Function { input, function: FunctionExpr::Fused(op), - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, - cast_to_supertypes: true, - ..Default::default() - }, + options, } } diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/options.rs b/polars/polars-lazy/polars-plan/src/logical_plan/options.rs index a4372c3dbd5f..61d04ba740c4 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/options.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/options.rs @@ -171,6 +171,16 @@ pub enum ApplyOptions { ApplyFlat, } +// a boolean that can only be set to `false` safely +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UnsafeBool(bool); +impl Default for UnsafeBool { + fn default() -> Self { + UnsafeBool(true) + } +} + #[derive(Clone, Copy, PartialEq, Eq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct FunctionOptions { @@ -217,6 +227,9 @@ pub struct FunctionOptions { pub pass_name_to_apply: bool, // For example a `unique` or a `slice` pub changes_length: bool, + // Validate the output of a `map`. + // this should always be true or we could OOB + pub check_lengths: UnsafeBool, } impl FunctionOptions { @@ -227,6 +240,14 @@ impl FunctionOptions { pub fn is_groups_sensitive(&self) -> bool { matches!(self.collect_groups, ApplyOptions::ApplyGroups) } + + #[cfg(feature = "fused")] + pub(crate) unsafe fn no_check_lengths(&mut self) { + self.check_lengths = UnsafeBool(false); + } + pub fn check_lengths(&self) -> bool { + self.check_lengths.0 + } } impl Default for FunctionOptions { @@ -240,6 +261,7 @@ impl Default for FunctionOptions { allow_rename: false, pass_name_to_apply: false, changes_length: false, + check_lengths: UnsafeBool(true), } } } diff --git a/polars/polars-lazy/src/physical_plan/expressions/apply.rs b/polars/polars-lazy/src/physical_plan/expressions/apply.rs index 7847c348119c..ee64751ae7ed 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/apply.rs @@ -25,6 +25,7 @@ pub struct ApplyExpr { pub pass_name_to_apply: bool, pub input_schema: Option, pub allow_threading: bool, + pub check_lengths: bool, } impl ApplyExpr { @@ -44,6 +45,7 @@ impl ApplyExpr { pass_name_to_apply: false, input_schema: None, allow_threading: true, + check_lengths: true, } } @@ -311,7 +313,12 @@ impl PhysicalExpr for ApplyExpr { { self.apply_multiple_group_aware(acs, df) } else { - apply_multiple_elementwise(acs, self.function.as_ref(), &self.expr) + apply_multiple_elementwise( + acs, + self.function.as_ref(), + &self.expr, + self.check_lengths, + ) } } } @@ -350,6 +357,7 @@ fn apply_multiple_elementwise<'a>( mut acs: Vec>, function: &dyn SeriesUdf, expr: &Expr, + check_lengths: bool, ) -> PolarsResult> { match acs.first().unwrap().agg_state() { // a fast path that doesn't drop groups of the first arg @@ -388,9 +396,11 @@ fn apply_multiple_elementwise<'a>( }) .collect::>(); - let input_len = s.iter().map(|s| s.len()).max().unwrap(); + let input_len = s[0].len(); let s = function.call_udf(&mut s)?.unwrap(); - check_map_output_len(input_len, s.len(), expr)?; + if check_lengths { + check_map_output_len(input_len, s.len(), expr)?; + } // take the first aggregation context that as that is the input series let mut ac = acs.swap_remove(0); diff --git a/polars/polars-lazy/src/physical_plan/planner/expr.rs b/polars/polars-lazy/src/physical_plan/planner/expr.rs index d7607b325212..dba817d10116 100644 --- a/polars/polars-lazy/src/physical_plan/planner/expr.rs +++ b/polars/polars-lazy/src/physical_plan/planner/expr.rs @@ -463,6 +463,7 @@ pub(crate) fn create_physical_expr( pass_name_to_apply: options.pass_name_to_apply, input_schema: schema.cloned(), allow_threading: !state.has_cache, + check_lengths: options.check_lengths(), })) } Function { @@ -497,6 +498,7 @@ pub(crate) fn create_physical_expr( pass_name_to_apply: options.pass_name_to_apply, input_schema: schema.cloned(), allow_threading: !state.has_cache, + check_lengths: options.check_lengths(), })) } Slice { diff --git a/py-polars/tests/unit/test_empty.py b/py-polars/tests/unit/test_empty.py index 298a28ec8697..3fa0bcda9f7b 100644 --- a/py-polars/tests/unit/test_empty.py +++ b/py-polars/tests/unit/test_empty.py @@ -60,3 +60,13 @@ def test_empty_sort_by_args() -> None: df = pl.DataFrame([1, 2, 3]) with pytest.raises(pl.InvalidOperationError): df.select(pl.all().sort_by([])) + + +def test_empty_9137() -> None: + out = ( + pl.DataFrame({"id": [], "value": []}) + .groupby("id") + .agg(pl.col("value").pow(2).mean()) + ) + assert out.shape == (0, 2) + assert out.dtypes == [pl.Float32, pl.Float32] diff --git a/py-polars/tests/unit/test_expr_multi_cols.py b/py-polars/tests/unit/test_expr_multi_cols.py index d533c49b6997..d07ed7a492e4 100644 --- a/py-polars/tests/unit/test_expr_multi_cols.py +++ b/py-polars/tests/unit/test_expr_multi_cols.py @@ -81,3 +81,20 @@ def test_append_root_columns() -> None: ] ) ).columns == ["col2", "col1", "prefix_col1", "col1_suffix"] + + +def test_multiple_columns_length_9137() -> None: + df = pl.DataFrame( + { + "a": [1, 1], + "b": ["c", "d"], + } + ) + + # list is larger than groups + cmp_list = ["a", "b", "c"] + + assert df.groupby("a").agg(pl.col("b").is_in(cmp_list)).to_dict(False) == { + "a": [1], + "b": [[True, False]], + }