diff --git a/src/cmd/outliers.rs b/src/cmd/outliers.rs index 88282b626..ed3be1cfd 100644 --- a/src/cmd/outliers.rs +++ b/src/cmd/outliers.rs @@ -54,6 +54,26 @@ use crate::{ CliResult, }; +use std::collections::HashSet; +use std::sync::{Mutex, OnceLock}; + +static ANTIMODE_CACHE: OnceLock>> = OnceLock::new(); + +// Helper function to get or create cached antimodes +fn get_cached_antimodes(antimode: &str) -> HashSet { + let cache = ANTIMODE_CACHE.get_or_init(|| Mutex::new(HashSet::new())); + let mut cache = cache.lock().unwrap(); + if cache.is_empty() { + cache.extend( + antimode + .split('|') + .map(String::from) + .collect::>() + ); + } + cache.clone() +} + #[derive(Deserialize)] struct Args { cmd_remove: bool, @@ -89,31 +109,23 @@ fn is_outlier(value: f64, lower_fence: f64, upper_fence: f64) -> bool { value < lower_fence || value > upper_fence } -pub fn run(argv: &[&str]) -> CliResult<()> { - let args: Args = util::get_args(USAGE, argv)?; - - // Get stats records (we still need these for the fences/thresholds) - let schema_args = util::SchemaArgs { - flag_enum_threshold: 0, - flag_ignore_case: false, - flag_strict_dates: false, - flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(), - flag_dates_whitelist: String::new(), - flag_prefer_dmy: false, - flag_force: args.flag_force, - flag_stdout: false, - flag_jobs: None, - flag_no_headers: false, - flag_delimiter: args.flag_delimiter, - arg_input: args.arg_input.clone(), - flag_memcheck: false, - }; - let (_csv_fields, csv_stats) = get_stats_records(&schema_args, StatsMode::Outliers)?; +struct CsvSetup { + reader: csv::Reader>, + writer: csv::Writer>, + headers: ByteRecord, + selected_stats: Vec, + progress_bar: Option, +} +fn setup_csv( + args: &Args, + csv_stats: &[StatsData], + write_outlier_headers: bool, +) -> CliResult { // Setup CSV reader with selection let rconfig = Config::new(args.arg_input.as_ref()) .delimiter(args.flag_delimiter) - .select(args.flag_select); + .select(args.flag_select.clone()); let mut rdr = rconfig.reader()?; // Get headers and create selection @@ -137,20 +149,24 @@ pub fn run(argv: &[&str]) -> CliResult<()> { .delimiter(args.flag_delimiter.unwrap_or(Delimiter(b',')).0) .from_writer(wtr); - // Write CSV headers - csv_wtr.write_record([ - "column", - "data_type", - "value", - "record_number", - "fence_type", - "reason", - "lower_fence", - "upper_fence", - ])?; - - // Setup progress bar if not quiet - let pb = if args.flag_quiet { + // Write headers based on mode + if write_outlier_headers { + csv_wtr.write_record([ + "column", + "data_type", + "value", + "record_number", + "fence_type", + "reason", + "lower_fence", + "upper_fence", + ])?; + } else { + csv_wtr.write_record(&headers)?; + } + + // Setup progress bar + let progress_bar = if args.flag_quiet { None } else { let pb = ProgressBar::new_spinner(); @@ -162,20 +178,172 @@ pub fn run(argv: &[&str]) -> CliResult<()> { Some(pb) }; - // Process records one at a time + Ok(CsvSetup { + reader: rdr, + writer: csv_wtr, + headers, + selected_stats, + progress_bar, + }) +} + +pub fn run(argv: &[&str]) -> CliResult<()> { + let args: Args = util::get_args(USAGE, argv)?; + + // Get stats records (we still need these for the fences/thresholds) + let schema_args = util::SchemaArgs { + flag_enum_threshold: 0, + flag_ignore_case: false, + flag_strict_dates: false, + flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(), + flag_dates_whitelist: String::new(), + flag_prefer_dmy: false, + flag_force: args.flag_force, + flag_stdout: false, + flag_jobs: None, + flag_no_headers: false, + flag_delimiter: args.flag_delimiter, + arg_input: args.arg_input.clone(), + flag_memcheck: false, + }; + let (_csv_fields, csv_stats) = get_stats_records(&schema_args, StatsMode::Outliers)?; + eprintln!("csv_stats: {:#?}", csv_stats); + + if args.cmd_remove { + remove_outliers(&args, &csv_stats) + } else { + identify_outliers(&args, &csv_stats) + } +} + +// New function to handle the remove subcommand +fn remove_outliers(args: &Args, csv_stats: &[StatsData]) -> CliResult<()> { + let mut setup = setup_csv(args, csv_stats, false)?; + let method = FenceType::from_str(args.flag_method.as_deref().unwrap_or("outer")); + let mut record = ByteRecord::new(); + let mut record_count = 0u64; + let mut removed_count = 0u64; + + while setup.reader.read_byte_record(&mut record)? { + record_count += 1; + if let Some(pb) = &setup.progress_bar { + pb.set_position(record_count); + } + + let mut is_outlier = false; + + // Check each selected column for outliers + for (col_idx, stat) in setup.selected_stats.iter().enumerate() { + let field = record.get(col_idx).unwrap_or_default(); + + match stat.r#type.as_str() { + "Integer" | "Float" => { + if let Some(val) = str::from_utf8(field) + .ok() + .and_then(|s| s.parse::().ok()) + { + is_outlier |= is_numeric_outlier(val, stat, &method); + } + }, + "String" => { + if let Ok(val) = str::from_utf8(field) { + is_outlier |= is_string_outlier(val, stat); + } + }, + _ => {}, + } + + if is_outlier { + break; // No need to check other columns if we found an outlier + } + } + + // Write record only if it's not an outlier + if is_outlier { + removed_count += 1; + } else { + setup.writer.write_record(&record)?; + } + } + + if let Some(pb) = &setup.progress_bar { + pb.finish_with_message(format!( + "Processed {record_count} records, removed {removed_count} outliers" + )); + } + + setup.writer.flush()?; + Ok(()) +} + +// New helper function for checking numeric outliers without writing +fn is_numeric_outlier(value: f64, stat: &StatsData, method: &FenceType) -> bool { + if let (Some(lower_inner), Some(upper_inner), Some(lower_outer), Some(upper_outer)) = ( + stat.lower_inner_fence, + stat.upper_inner_fence, + stat.lower_outer_fence, + stat.upper_outer_fence, + ) { + let (is_inner, is_outer) = ( + is_outlier(value, lower_inner, upper_inner), + is_outlier(value, lower_outer, upper_outer), + ); + + match method { + FenceType::Inner => is_inner, + FenceType::Outer => is_outer, + FenceType::Both => is_inner || is_outer, + } + } else { + false + } +} + +// Helper function for checking string outliers +fn is_string_outlier(value: &str, stat: &StatsData) -> bool { + // Check string length outliers + if let (Some(mean_len), Some(stddev_len)) = (stat.avg_length, stat.stddev_length) { + #[allow(clippy::cast_precision_loss)] + let len = value.len() as f64; + let z_score = (len - mean_len) / stddev_len; + if z_score.abs() > 3.0 { + return true; + } + } + + // Check rare categories with cached antimodes + if let Some(ref antimode) = stat.antimode { + if !antimode.starts_with("*ALL") { + let cached_antimodes = get_cached_antimodes(antimode); + if cached_antimodes.contains(value) { + return true; + } + } + } + + false +} + +fn identify_outliers(args: &Args, csv_stats: &[StatsData]) -> CliResult<()> { + let mut setup = setup_csv(args, csv_stats, true)?; let method = FenceType::from_str(args.flag_method.as_deref().unwrap_or("outer")); let mut record = ByteRecord::new(); let mut record_count = 0u64; - while rdr.read_byte_record(&mut record)? { + while setup.reader.read_byte_record(&mut record)? { record_count += 1; - if let Some(pb) = &pb { + if let Some(pb) = &setup.progress_bar { pb.set_position(record_count); } // Process each selected column - for (col_idx, stat) in selected_stats.iter().enumerate() { - let field = record.get(sel[col_idx]).unwrap_or_default(); + for stat in &setup.selected_stats { + let col_idx = setup + .headers + .iter() + .position(|h| h == stat.field.as_bytes()) + .unwrap_or(0); + let field = record.get(col_idx).unwrap_or_default(); match stat.r#type.as_str() { "Integer" | "Float" => { @@ -183,12 +351,12 @@ pub fn run(argv: &[&str]) -> CliResult<()> { .ok() .and_then(|s| s.parse::().ok()) { - check_numeric_outlier(val, stat, &method, record_count, &mut csv_wtr)?; + check_numeric_outlier(val, stat, &method, record_count, &mut setup.writer)?; } }, "String" => { if let Ok(val) = str::from_utf8(field) { - check_string_outlier(val, stat, record_count, &mut csv_wtr)?; + check_string_outlier(val, stat, record_count, &mut setup.writer)?; } }, _ => {}, @@ -196,11 +364,11 @@ pub fn run(argv: &[&str]) -> CliResult<()> { } } - if let Some(pb) = &pb { + if let Some(pb) = &setup.progress_bar { pb.finish_with_message(format!("Processed {record_count} records")); } - csv_wtr.flush()?; + setup.writer.flush()?; Ok(()) } @@ -262,7 +430,10 @@ fn check_string_outlier( ) -> CliResult<()> { // Check string length outliers if let (Some(mean_len), Some(stddev_len)) = (stat.avg_length, stat.stddev_length) { - println!("mean_len: {mean_len}, stddev_len: {stddev_len} value_len: {}", value.len()); + eprintln!( + "mean_len: {mean_len}, stddev_len: {stddev_len} value_len: {}", + value.len() + ); #[allow(clippy::cast_precision_loss)] let len = value.len() as f64; let z_score = (len - mean_len) / stddev_len; @@ -298,7 +469,6 @@ fn check_string_outlier( ])?; } } - } Ok(()) }