From 83f4f01eca6ee257eac460f4a9325f5b021957be Mon Sep 17 00:00:00 2001 From: ptiza Date: Fri, 15 Sep 2023 15:12:31 +0200 Subject: [PATCH] Fix rust test for logical plan optimizer for categoricals #9828 --- .../optimizer/type_coercion/mod.rs | 76 +++++++++++++------ 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index ba92141239fe..7002bed4e95a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -519,28 +519,56 @@ fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> { } } -// TODO: Fix this test and re-enable it (currently does not compile) -// #[cfg(test)] +#[cfg(test)] // #[cfg(feature = "dtype-categorical")] -// mod test { -// use polars_core::prelude::*; - -// use super::*; -// use crate::prelude::*; - -// #[test] -// fn test_categorical_utf8() { -// let mut rules: Vec> = vec![Box::new(TypeCoercionRule {})]; -// let schema = Schema::from_iter([Field::new("fruits", DataType::Categorical(None))]); - -// let expr = col("fruits").eq(lit("somestr")); -// let out = optimize_expr(expr.clone(), schema.clone(), &mut rules); -// // we test that the fruits column is not casted to utf8 for the comparison -// assert_eq!(out, expr); - -// let expr = col("fruits") + (lit("somestr")); -// let out = optimize_expr(expr, schema, &mut rules); -// let expected = col("fruits").cast(DataType::Utf8) + lit("somestr"); -// assert_eq!(out, expected); -// } -// } +mod test { + use polars_core::prelude::*; + + use super::*; + + #[test] + fn test_categorical_utf8() { + let mut expr_arena = Arena::new(); + let mut lp_arena = Arena::new(); + let optimizer = StackOptimizer {}; + let rules: &mut [Box] = &mut [Box::new(TypeCoercionRule {})]; + + let df = DataFrame::new(Vec::from([Series::new_empty( + "fruits", + &DataType::Categorical(None), + )])) + .unwrap(); + + let expr_in = vec![col("fruits").eq(lit("somestr"))]; + let lp = LogicalPlanBuilder::from_existing_df(df.clone()) + .project(expr_in.clone(), Default::default()) + .build(); + + let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena).unwrap(); + lp_top = optimizer + .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) + .unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + // we test that the fruits column is not casted to utf8 for the comparison + if let LogicalPlan::Projection { expr, .. } = lp { + assert_eq!(expr, expr_in); + }; + + let expr_in = vec![col("fruits") + (lit("somestr"))]; + let lp = LogicalPlanBuilder::from_existing_df(df) + .project(expr_in, Default::default()) + .build(); + let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena).unwrap(); + lp_top = optimizer + .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) + .unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + // we test that the fruits column is casted to utf8 for the addition + let expected = vec![col("fruits").cast(DataType::Utf8) + lit("somestr")]; + if let LogicalPlan::Projection { expr, .. } = lp { + assert_eq!(expr, expected); + }; + } +}