Skip to content

Commit

Permalink
fix(rust, python): let apply caller determine if length needs to be…
Browse files Browse the repository at this point in the history
… checked. (pola-rs#9140)
  • Loading branch information
ritchie46 authored and c-peters committed Jul 14, 2023
1 parent 22cd2e3 commit 2388246
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 8 deletions.
15 changes: 10 additions & 5 deletions polars/polars-lazy/polars-plan/src/logical_plan/optimizer/fused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ use super::*;
pub struct FusedArithmetic {}

/// Builds the `AExpr::Function` node that evaluates the fused operator `op`
/// over the expression nodes in `input`.
///
/// The options use `ApplyOptions::ApplyFlat`, enable `cast_to_supertypes`,
/// keep the defaults for every other field, and have the length check
/// switched off via `no_check_lengths`.
fn get_expr(input: Vec<Node>, op: FusedOperator) -> AExpr {
let mut options = FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
cast_to_supertypes: true,
..Default::default()
};
// FMA changes the order of operations, so the length check has to be
// toggled off for this node.
// SAFETY: considered sound by the author because this is a trusted
// operation.
unsafe { options.no_check_lengths() }
AExpr::Function {
input,
function: FunctionExpr::Fused(op),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
cast_to_supertypes: true,
..Default::default()
},
options,
}
}

Expand Down
22 changes: 22 additions & 0 deletions polars/polars-lazy/polars-plan/src/logical_plan/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ pub enum ApplyOptions {
ApplyFlat,
}

/// A boolean flag whose default value is `true`.
///
/// The wrapped `bool` is a private tuple field, so it cannot be written
/// directly from outside the module that declares this type.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UnsafeBool(bool);

impl Default for UnsafeBool {
    /// Yields the enabled state, i.e. a wrapped `true`.
    fn default() -> Self {
        Self(true)
    }
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FunctionOptions {
Expand Down Expand Up @@ -217,6 +227,9 @@ pub struct FunctionOptions {
pub pass_name_to_apply: bool,
// For example a `unique` or a `slice`
pub changes_length: bool,
// Validate the output length of a `map`.
// This should always be true; otherwise we could go out of bounds (OOB).
pub check_lengths: UnsafeBool,
}

impl FunctionOptions {
Expand All @@ -227,6 +240,14 @@ impl FunctionOptions {
/// Returns `true` when `collect_groups` is `ApplyOptions::ApplyGroups`.
pub fn is_groups_sensitive(&self) -> bool {
matches!(self.collect_groups, ApplyOptions::ApplyGroups)
}

/// Turns the length check off by setting `check_lengths` to
/// `UnsafeBool(false)`.
///
/// Only compiled when the `fused` feature is enabled.
///
/// # Safety
///
/// With the check disabled, the output length of the function is no longer
/// validated, which may allow out-of-bounds access. The caller must only
/// use this for a function that is trusted to produce output of a valid
/// length.
#[cfg(feature = "fused")]
pub(crate) unsafe fn no_check_lengths(&mut self) {
self.check_lengths = UnsafeBool(false);
}
/// Returns the `bool` wrapped by the `check_lengths` field, i.e. whether the
/// length check is enabled.
pub fn check_lengths(&self) -> bool {
self.check_lengths.0
}
}

impl Default for FunctionOptions {
Expand All @@ -240,6 +261,7 @@ impl Default for FunctionOptions {
allow_rename: false,
pass_name_to_apply: false,
changes_length: false,
check_lengths: UnsafeBool(true),
}
}
}
Expand Down
16 changes: 13 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct ApplyExpr {
pub pass_name_to_apply: bool,
pub input_schema: Option<SchemaRef>,
pub allow_threading: bool,
pub check_lengths: bool,
}

impl ApplyExpr {
Expand All @@ -44,6 +45,7 @@ impl ApplyExpr {
pass_name_to_apply: false,
input_schema: None,
allow_threading: true,
check_lengths: true,
}
}

Expand Down Expand Up @@ -311,7 +313,12 @@ impl PhysicalExpr for ApplyExpr {
{
self.apply_multiple_group_aware(acs, df)
} else {
apply_multiple_elementwise(acs, self.function.as_ref(), &self.expr)
apply_multiple_elementwise(
acs,
self.function.as_ref(),
&self.expr,
self.check_lengths,
)
}
}
}
Expand Down Expand Up @@ -350,6 +357,7 @@ fn apply_multiple_elementwise<'a>(
mut acs: Vec<AggregationContext<'a>>,
function: &dyn SeriesUdf,
expr: &Expr,
check_lengths: bool,
) -> PolarsResult<AggregationContext<'a>> {
match acs.first().unwrap().agg_state() {
// a fast path that doesn't drop groups of the first arg
Expand Down Expand Up @@ -388,9 +396,11 @@ fn apply_multiple_elementwise<'a>(
})
.collect::<Vec<_>>();

let input_len = s.iter().map(|s| s.len()).max().unwrap();
let input_len = s[0].len();
let s = function.call_udf(&mut s)?.unwrap();
check_map_output_len(input_len, s.len(), expr)?;
if check_lengths {
check_map_output_len(input_len, s.len(), expr)?;
}

// take the first aggregation context, as that is the input series
let mut ac = acs.swap_remove(0);
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-lazy/src/physical_plan/planner/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ pub(crate) fn create_physical_expr(
pass_name_to_apply: options.pass_name_to_apply,
input_schema: schema.cloned(),
allow_threading: !state.has_cache,
check_lengths: options.check_lengths(),
}))
}
Function {
Expand Down Expand Up @@ -497,6 +498,7 @@ pub(crate) fn create_physical_expr(
pass_name_to_apply: options.pass_name_to_apply,
input_schema: schema.cloned(),
allow_threading: !state.has_cache,
check_lengths: options.check_lengths(),
}))
}
Slice {
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/test_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,13 @@ def test_empty_sort_by_args() -> None:
df = pl.DataFrame([1, 2, 3])
with pytest.raises(pl.InvalidOperationError):
df.select(pl.all().sort_by([]))


def test_empty_9137() -> None:
    # Group and aggregate a frame with no rows; the result must keep both
    # columns, have zero rows, and report Float32 dtypes.
    empty_df = pl.DataFrame({"id": [], "value": []})
    squared_mean = pl.col("value").pow(2).mean()
    out = empty_df.groupby("id").agg(squared_mean)
    assert out.shape == (0, 2)
    assert out.dtypes == [pl.Float32, pl.Float32]
17 changes: 17 additions & 0 deletions py-polars/tests/unit/test_expr_multi_cols.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,20 @@ def test_append_root_columns() -> None:
]
)
).columns == ["col2", "col1", "prefix_col1", "col1_suffix"]


def test_multiple_columns_length_9137() -> None:
    frame = pl.DataFrame({"a": [1, 1], "b": ["c", "d"]})

    # The comparison list holds more entries than the group has rows.
    cmp_list = ["a", "b", "c"]

    result = frame.groupby("a").agg(pl.col("b").is_in(cmp_list)).to_dict(False)
    expected = {"a": [1], "b": [[True, False]]}
    assert result == expected

0 comments on commit 2388246

Please sign in to comment.