Skip to content

Commit

Permalink
Fix rust test for logical plan optimizer for categoricals pola-rs#9828
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiza committed Sep 15, 2023
1 parent 07b9033 commit 83f4f01
Showing 1 changed file with 52 additions and 24 deletions.
76 changes: 52 additions & 24 deletions crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,28 +519,56 @@ fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> {
}
}

// TODO: Fix this test and re-enable it (currently does not compile)
// #[cfg(test)]
#[cfg(test)]
// #[cfg(feature = "dtype-categorical")]
// mod test {
// use polars_core::prelude::*;

// use super::*;
// use crate::prelude::*;

// #[test]
// fn test_categorical_utf8() {
// let mut rules: Vec<Box<dyn OptimizationRule>> = vec![Box::new(TypeCoercionRule {})];
// let schema = Schema::from_iter([Field::new("fruits", DataType::Categorical(None))]);

// let expr = col("fruits").eq(lit("somestr"));
// let out = optimize_expr(expr.clone(), schema.clone(), &mut rules);
// // we test that the fruits column is not casted to utf8 for the comparison
// assert_eq!(out, expr);

// let expr = col("fruits") + (lit("somestr"));
// let out = optimize_expr(expr, schema, &mut rules);
// let expected = col("fruits").cast(DataType::Utf8) + lit("somestr");
// assert_eq!(out, expected);
// }
// }
mod test {
use polars_core::prelude::*;

use super::*;

#[test]
fn test_categorical_utf8() {
let mut expr_arena = Arena::new();
let mut lp_arena = Arena::new();
let optimizer = StackOptimizer {};
let rules: &mut [Box<dyn OptimizationRule>] = &mut [Box::new(TypeCoercionRule {})];

let df = DataFrame::new(Vec::from([Series::new_empty(
"fruits",
&DataType::Categorical(None),
)]))
.unwrap();

let expr_in = vec![col("fruits").eq(lit("somestr"))];
let lp = LogicalPlanBuilder::from_existing_df(df.clone())
.project(expr_in.clone(), Default::default())
.build();

let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena).unwrap();
lp_top = optimizer
.optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top)
.unwrap();
let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena);

// we test that the fruits column is not casted to utf8 for the comparison
if let LogicalPlan::Projection { expr, .. } = lp {
assert_eq!(expr, expr_in);
};

let expr_in = vec![col("fruits") + (lit("somestr"))];
let lp = LogicalPlanBuilder::from_existing_df(df)
.project(expr_in, Default::default())
.build();
let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena).unwrap();
lp_top = optimizer
.optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top)
.unwrap();
let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena);

// we test that the fruits column is casted to utf8 for the addition
let expected = vec![col("fruits").cast(DataType::Utf8) + lit("somestr")];
if let LogicalPlan::Projection { expr, .. } = lp {
assert_eq!(expr, expected);
};
}
}

0 comments on commit 83f4f01

Please sign in to comment.