Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Always expand horizontal_any/all #15816

Merged
merged 3 commits into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::*;
fn cached_before_root(q: LazyFrame) {
let (mut expr_arena, mut lp_arena) = get_arenas();
let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();
for input in lp_arena.get(lp).get_inputs() {
for input in lp_arena.get(lp).get_inputs_vec() {
assert!(matches!(lp_arena.get(input), IR::Cache { .. }));
}
}
Expand Down
40 changes: 0 additions & 40 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,5 @@
use std::ops::{BitAnd, BitOr};

use polars_core::frame::NullStrategy;
use polars_core::prelude::*;
use polars_core::POOL;
use rayon::prelude::*;

pub fn any_horizontal(s: &[Series]) -> PolarsResult<Series> {
let out = POOL
.install(|| {
s.par_iter()
.try_fold(
|| BooleanChunked::new("", &[false]),
|acc, b| {
let b = b.cast(&DataType::Boolean)?;
let b = b.bool()?;
PolarsResult::Ok((&acc).bitor(b))
},
)
.try_reduce(|| BooleanChunked::new("", [false]), |a, b| Ok(a.bitor(b)))
})?
.with_name(s[0].name());
Ok(out.into_series())
}

pub fn all_horizontal_impl(s: &[Series]) -> PolarsResult<Series> {
let out = POOL
.install(|| {
s.par_iter()
.try_fold(
|| BooleanChunked::new("", &[true]),
|acc, b| {
let b = b.cast(&DataType::Boolean)?;
let b = b.bool()?;
PolarsResult::Ok((&acc).bitand(b))
},
)
.try_reduce(|| BooleanChunked::new("", [true]), |a, b| Ok(a.bitand(b)))
})?
.with_name(s[0].name());
Ok(out.into_series())
}

pub fn max_horizontal(s: &[Series]) -> PolarsResult<Option<Series>> {
let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) };
Expand Down
15 changes: 4 additions & 11 deletions crates/polars-plan/src/dsl/function_expr/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use super::*;
use crate::map;
#[cfg(feature = "is_between")]
use crate::map_as_slice;
#[cfg(feature = "is_in")]
use crate::wrap;
use crate::{map, map_as_slice};

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
Expand Down Expand Up @@ -112,9 +114,8 @@ impl From<BooleanFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
IsBetween { closed } => map_as_slice!(is_between, closed),
#[cfg(feature = "is_in")]
IsIn => wrap!(is_in),
AllHorizontal => map_as_slice!(all_horizontal),
AnyHorizontal => map_as_slice!(any_horizontal),
Not => map!(not),
AllHorizontal | AnyHorizontal => unreachable!(),
}
}
}
Expand Down Expand Up @@ -202,14 +203,6 @@ fn is_in(s: &mut [Series]) -> PolarsResult<Option<Series>> {
polars_ops::prelude::is_in(left, other).map(|ca| Some(ca.into_series()))
}

/// Thin wrapper that forwards the input slice unchanged to
/// `polars_ops::prelude::any_horizontal` and returns its result.
fn any_horizontal(s: &[Series]) -> PolarsResult<Series> {
polars_ops::prelude::any_horizontal(s)
}

/// Thin wrapper that forwards the input slice unchanged to
/// `polars_ops::prelude::all_horizontal_impl` and returns its result.
fn all_horizontal(s: &[Series]) -> PolarsResult<Series> {
polars_ops::prelude::all_horizontal_impl(s)
}

/// Thin wrapper that forwards the series to
/// `polars_ops::series::negate_bitwise` and returns its result.
fn not(s: &Series) -> PolarsResult<Series> {
polars_ops::series::negate_bitwise(s)
}
27 changes: 2 additions & 25 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,26 +195,12 @@ where
pub fn all_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
let exprs = exprs.as_ref().to_vec();
polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown");

// We prefer this path as the optimizer can better deal with the binary operations.
// However if we have a single expression, we might lose information.
// E.g. `all().is_null()` would reduce to `all().is_null()` (the & is not needed as there is no rhs yet).
// Upon expansion, it becomes
// `col(i).is_null() for i in len(df)`,
// so we would miss the boolean operator.
if exprs.len() > 1 {
return Ok(exprs.into_iter().reduce(|l, r| l.logical_and(r)).unwrap());
}

// This will be reduced to `expr & expr` during conversion to IR.
Ok(Expr::Function {
input: exprs,
function: FunctionExpr::Boolean(BooleanFunction::AllHorizontal),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
cast_to_supertypes: false,
allow_rename: true,
..Default::default()
},
})
Expand All @@ -226,21 +212,12 @@ pub fn all_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
pub fn any_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
let exprs = exprs.as_ref().to_vec();
polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown");

// See comment in `all_horizontal`.
if exprs.len() > 1 {
return Ok(exprs.into_iter().reduce(|l, r| l.logical_or(r)).unwrap());
}

// This will be reduced to `expr | expr` during conversion to IR.
Ok(Expr::Function {
input: exprs,
function: FunctionExpr::Boolean(BooleanFunction::AnyHorizontal),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
cast_to_supertypes: false,
allow_rename: true,
..Default::default()
},
})
Expand Down
17 changes: 9 additions & 8 deletions crates/polars-plan/src/logical_plan/alp/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,18 @@ impl IR {
container.push_node(input)
}

pub fn get_inputs(&self) -> Vec<Node> {
let mut inputs = Vec::new();
pub fn get_inputs(&self) -> UnitVec<Node> {
let mut inputs: UnitVec<Node> = unitvec!();
self.copy_inputs(&mut inputs);
inputs
}

/// Collects this node's inputs into a newly created `Vec<Node>` by handing
/// the vector to `copy_inputs`.
pub fn get_inputs_vec(&self) -> Vec<Node> {
    let mut nodes: Vec<Node> = Vec::new();
    self.copy_inputs(&mut nodes);
    nodes
}
/// panics if more than one input
#[cfg(any(
all(feature = "strings", feature = "concat_str"),
feature = "streaming",
feature = "fused"
))]

pub(crate) fn get_input(&self) -> Option<Node> {
let mut inputs: UnitVec<Node> = unitvec!();
self.copy_inputs(&mut inputs);
Expand Down
15 changes: 15 additions & 0 deletions crates/polars-plan/src/logical_plan/alp/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ impl IR {
}
}

/// Returns the schema this node reads from.
///
/// Scan-like nodes (`PythonScan`, `DataFrameScan`, `Scan`) borrow the schema
/// they carry themselves. Any other node resolves its input through
/// `get_input` and yields that input's `schema`; when `get_input` returns
/// `None`, so does this method.
pub fn input_schema<'a>(&'a self, arena: &'a Arena<IR>) -> Option<Cow<'a, SchemaRef>> {
    use IR::*;
    match self {
        #[cfg(feature = "python")]
        PythonScan { options, .. } => Some(Cow::Borrowed(&options.schema)),
        DataFrameScan { schema, .. } => Some(Cow::Borrowed(schema)),
        Scan { file_info, .. } => Some(Cow::Borrowed(&file_info.schema)),
        // Everything else: look the schema up on the input node in the arena.
        other => other
            .get_input()
            .map(|input| arena.get(input).schema(arena)),
    }
}

/// Get the schema of the logical plan node.
pub fn schema<'a>(&'a self, arena: &'a Arena<IR>) -> Cow<'a, SchemaRef> {
use IR::*;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,27 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionSta
function,
options,
} => {
match function {
// Convert to binary expression as the optimizer understands those.
FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => {
let expr = input
.into_iter()
.reduce(|l, r| l.logical_and(r))
.unwrap()
.cast(DataType::Boolean);
return to_aexpr_impl(expr, arena, state);
},
FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => {
let expr = input
.into_iter()
.reduce(|l, r| l.logical_or(r))
.unwrap()
.cast(DataType::Boolean);
return to_aexpr_impl(expr, arena, state);
},
_ => {},
}

let e = to_expr_irs(input, arena);

if state.output_name.is_none() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl<'a> PredicatePushDown<'a> {
expr_arena: &mut Arena<AExpr>,
has_projections: bool,
) -> PolarsResult<IR> {
let inputs = lp.get_inputs();
let inputs = lp.get_inputs_vec();
let exprs = lp.get_exprs();

if has_projections {
Expand Down
40 changes: 32 additions & 8 deletions crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ impl OptimizationRule for SimplifyExprRule {
&mut self,
expr_arena: &mut Arena<AExpr>,
expr_node: Node,
_lp_arena: &Arena<IR>,
_lp_node: Node,
lp_arena: &Arena<IR>,
lp_node: Node,
) -> PolarsResult<Option<AExpr>> {
let expr = expr_arena.get(expr_node).clone();

Expand All @@ -443,8 +443,8 @@ impl OptimizationRule for SimplifyExprRule {
#[cfg(all(feature = "strings", feature = "concat_str"))]
{
string_addition_to_linear_concat(
_lp_arena,
_lp_node,
lp_arena,
lp_node,
expr_arena,
*left,
*right,
Expand Down Expand Up @@ -595,19 +595,43 @@ impl OptimizationRule for SimplifyExprRule {
strict,
} => {
let input = expr_arena.get(*expr);
inline_cast(input, data_type, *strict)?
inline_or_prune_cast(input, data_type, *strict, lp_node, lp_arena, expr_arena)?
},
_ => None,
};
Ok(out)
}
}

fn inline_cast(input: &AExpr, dtype: &DataType, strict: bool) -> PolarsResult<Option<AExpr>> {
fn inline_or_prune_cast(
aexpr: &AExpr,
dtype: &DataType,
strict: bool,
lp_node: Node,
lp_arena: &Arena<IR>,
expr_arena: &Arena<AExpr>,
) -> PolarsResult<Option<AExpr>> {
if !dtype.is_known() {
return Ok(None);
}
let lv = match (input, dtype) {
let lv = match (aexpr, dtype) {
// PRUNE
(
AExpr::BinaryExpr {
op: Operator::LogicalOr | Operator::LogicalAnd,
..
},
_,
) => {
if let Some(schema) = lp_arena.get(lp_node).input_schema(lp_arena) {
let field = aexpr.to_field(&schema, Context::Default, expr_arena)?;
if field.dtype == *dtype {
return Ok(Some(aexpr.clone()));
}
}
return Ok(None);
},
// INLINE
(AExpr::Literal(lv), _) => match lv {
LiteralValue::Series(s) => {
let s = if strict {
Expand All @@ -622,7 +646,7 @@ fn inline_cast(input: &AExpr, dtype: &DataType, strict: bool) -> PolarsResult<Op
return Ok(None);
};
if dtype == &av.dtype() {
return Ok(Some(input.clone()));
return Ok(Some(aexpr.clone()));
}
match (av, dtype) {
// casting null always remains null
Expand Down
Loading