refactor: outliers to cache antimodes and unify CSV setup for both modes
jqnatividad committed Jan 1, 2025
1 parent aee7bee commit f95cd1e
Showing 1 changed file with 216 additions and 46 deletions.
262 changes: 216 additions & 46 deletions src/cmd/outliers.rs
@@ -54,6 +54,26 @@ use crate::{
CliResult,
};

use std::collections::HashSet;
use std::sync::{Mutex, OnceLock};

static ANTIMODE_CACHE: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();

// Helper function to get or create cached antimodes
fn get_cached_antimodes(antimode: &str) -> HashSet<String> {
let cache = ANTIMODE_CACHE.get_or_init(|| Mutex::new(HashSet::new()));
let mut cache = cache.lock().unwrap();
if cache.is_empty() {
cache.extend(
antimode
.split('|')
.map(String::from)
.collect::<HashSet<_>>()
);
}
cache.clone()
}

#[derive(Deserialize)]
struct Args {
cmd_remove: bool,
@@ -89,31 +109,23 @@ fn is_outlier(value: f64, lower_fence: f64, upper_fence: f64) -> bool {
value < lower_fence || value > upper_fence
}

pub fn run(argv: &[&str]) -> CliResult<()> {
let args: Args = util::get_args(USAGE, argv)?;

// Get stats records (we still need these for the fences/thresholds)
let schema_args = util::SchemaArgs {
flag_enum_threshold: 0,
flag_ignore_case: false,
flag_strict_dates: false,
flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(),
flag_dates_whitelist: String::new(),
flag_prefer_dmy: false,
flag_force: args.flag_force,
flag_stdout: false,
flag_jobs: None,
flag_no_headers: false,
flag_delimiter: args.flag_delimiter,
arg_input: args.arg_input.clone(),
flag_memcheck: false,
};
let (_csv_fields, csv_stats) = get_stats_records(&schema_args, StatsMode::Outliers)?;
struct CsvSetup {
reader: csv::Reader<Box<dyn io::Read + Send>>,
writer: csv::Writer<Box<dyn io::Write>>,
headers: ByteRecord,
selected_stats: Vec<StatsData>,
progress_bar: Option<ProgressBar>,
}

fn setup_csv(
args: &Args,
csv_stats: &[StatsData],
write_outlier_headers: bool,
) -> CliResult<CsvSetup> {
// Setup CSV reader with selection
let rconfig = Config::new(args.arg_input.as_ref())
.delimiter(args.flag_delimiter)
.select(args.flag_select);
.select(args.flag_select.clone());
let mut rdr = rconfig.reader()?;

// Get headers and create selection
@@ -137,20 +149,24 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
.delimiter(args.flag_delimiter.unwrap_or(Delimiter(b',')).0)
.from_writer(wtr);

// Write CSV headers
csv_wtr.write_record([
"column",
"data_type",
"value",
"record_number",
"fence_type",
"reason",
"lower_fence",
"upper_fence",
])?;

// Setup progress bar if not quiet
let pb = if args.flag_quiet {
// Write headers based on mode
if write_outlier_headers {
csv_wtr.write_record([
"column",
"data_type",
"value",
"record_number",
"fence_type",
"reason",
"lower_fence",
"upper_fence",
])?;
} else {
csv_wtr.write_record(&headers)?;
}

// Setup progress bar
let progress_bar = if args.flag_quiet {
None
} else {
let pb = ProgressBar::new_spinner();
@@ -162,45 +178,197 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
Some(pb)
};

// Process records one at a time
Ok(CsvSetup {
reader: rdr,
writer: csv_wtr,
headers,
selected_stats,
progress_bar,
})
}

pub fn run(argv: &[&str]) -> CliResult<()> {
let args: Args = util::get_args(USAGE, argv)?;

// Get stats records (we still need these for the fences/thresholds)
let schema_args = util::SchemaArgs {
flag_enum_threshold: 0,
flag_ignore_case: false,
flag_strict_dates: false,
flag_pattern_columns: crate::select::SelectColumns::parse("").unwrap(),
flag_dates_whitelist: String::new(),
flag_prefer_dmy: false,
flag_force: args.flag_force,
flag_stdout: false,
flag_jobs: None,
flag_no_headers: false,
flag_delimiter: args.flag_delimiter,
arg_input: args.arg_input.clone(),
flag_memcheck: false,
};
let (_csv_fields, csv_stats) = get_stats_records(&schema_args, StatsMode::Outliers)?;
eprintln!("csv_stats: {:#?}", csv_stats);

if args.cmd_remove {
remove_outliers(&args, &csv_stats)
} else {
identify_outliers(&args, &csv_stats)
}
}

// New function to handle the remove subcommand
fn remove_outliers(args: &Args, csv_stats: &[StatsData]) -> CliResult<()> {
let mut setup = setup_csv(args, csv_stats, false)?;
let method = FenceType::from_str(args.flag_method.as_deref().unwrap_or("outer"));
let mut record = ByteRecord::new();
let mut record_count = 0u64;
let mut removed_count = 0u64;

while setup.reader.read_byte_record(&mut record)? {
record_count += 1;
if let Some(pb) = &setup.progress_bar {
pb.set_position(record_count);
}

let mut is_outlier = false;

// Check each selected column for outliers
for (col_idx, stat) in setup.selected_stats.iter().enumerate() {
let field = record.get(col_idx).unwrap_or_default();

match stat.r#type.as_str() {
"Integer" | "Float" => {
if let Some(val) = str::from_utf8(field)
.ok()
.and_then(|s| s.parse::<f64>().ok())
{
is_outlier |= is_numeric_outlier(val, stat, &method);
}
},
"String" => {
if let Ok(val) = str::from_utf8(field) {
is_outlier |= is_string_outlier(val, stat);
}
},
_ => {},
}

if is_outlier {
break; // No need to check other columns if we found an outlier
}
}

// Write record only if it's not an outlier
if is_outlier {
removed_count += 1;
} else {
setup.writer.write_record(&record)?;
}
}

if let Some(pb) = &setup.progress_bar {
pb.finish_with_message(format!(
"Processed {record_count} records, removed {removed_count} outliers"
));
}

setup.writer.flush()?;
Ok(())
}

// New helper function for checking numeric outliers without writing
fn is_numeric_outlier(value: f64, stat: &StatsData, method: &FenceType) -> bool {
if let (Some(lower_inner), Some(upper_inner), Some(lower_outer), Some(upper_outer)) = (
stat.lower_inner_fence,
stat.upper_inner_fence,
stat.lower_outer_fence,
stat.upper_outer_fence,
) {
let (is_inner, is_outer) = (
is_outlier(value, lower_inner, upper_inner),
is_outlier(value, lower_outer, upper_outer),
);

match method {
FenceType::Inner => is_inner,
FenceType::Outer => is_outer,
FenceType::Both => is_inner || is_outer,
}
} else {
false
}
}

// Helper function for checking string outliers
fn is_string_outlier(value: &str, stat: &StatsData) -> bool {
// Check string length outliers
if let (Some(mean_len), Some(stddev_len)) = (stat.avg_length, stat.stddev_length) {
#[allow(clippy::cast_precision_loss)]
let len = value.len() as f64;
let z_score = (len - mean_len) / stddev_len;
if z_score.abs() > 3.0 {
return true;
}
}

// Check rare categories with cached antimodes
if let Some(ref antimode) = stat.antimode {
if !antimode.starts_with("*ALL") {
let cached_antimodes = get_cached_antimodes(antimode);
if cached_antimodes.contains(value) {
return true;
}
}
}

false
}

fn identify_outliers(args: &Args, csv_stats: &[StatsData]) -> CliResult<()> {
let mut setup = setup_csv(args, csv_stats, true)?;
let method = FenceType::from_str(args.flag_method.as_deref().unwrap_or("outer"));
let mut record = ByteRecord::new();
let mut record_count = 0u64;

while rdr.read_byte_record(&mut record)? {
while setup.reader.read_byte_record(&mut record)? {
record_count += 1;
if let Some(pb) = &pb {
if let Some(pb) = &setup.progress_bar {
pb.set_position(record_count);
}

// Process each selected column
for (col_idx, stat) in selected_stats.iter().enumerate() {
let field = record.get(sel[col_idx]).unwrap_or_default();
for stat in &setup.selected_stats {
let col_idx = setup
.headers
.iter()
.position(|h| h == stat.field.as_bytes())
.unwrap_or(0);
let field = record.get(col_idx).unwrap_or_default();

match stat.r#type.as_str() {
"Integer" | "Float" => {
if let Some(val) = str::from_utf8(field)
.ok()
.and_then(|s| s.parse::<f64>().ok())
{
check_numeric_outlier(val, stat, &method, record_count, &mut csv_wtr)?;
check_numeric_outlier(val, stat, &method, record_count, &mut setup.writer)?;
}
},
"String" => {
if let Ok(val) = str::from_utf8(field) {
check_string_outlier(val, stat, record_count, &mut csv_wtr)?;
check_string_outlier(val, stat, record_count, &mut setup.writer)?;
}
},
_ => {},
}
}
}

if let Some(pb) = &pb {
if let Some(pb) = &setup.progress_bar {
pb.finish_with_message(format!("Processed {record_count} records"));
}

csv_wtr.flush()?;
setup.writer.flush()?;
Ok(())
}

@@ -262,7 +430,10 @@ fn check_string_outlier(
) -> CliResult<()> {
// Check string length outliers
if let (Some(mean_len), Some(stddev_len)) = (stat.avg_length, stat.stddev_length) {
println!("mean_len: {mean_len}, stddev_len: {stddev_len} value_len: {}", value.len());
eprintln!(
"mean_len: {mean_len}, stddev_len: {stddev_len} value_len: {}",
value.len()
);
#[allow(clippy::cast_precision_loss)]
let len = value.len() as f64;
let z_score = (len - mean_len) / stddev_len;
@@ -298,7 +469,6 @@ fn check_string_outlier(
])?;
}
}

}
Ok(())
}
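
A minimal standalone sketch of the antimode cache this commit introduces. It reproduces get_cached_antimodes from the diff above so the snippet compiles on its own; the main function and the "rare_a|rare_b" antimode string are made-up illustrations, not qsv output.

use std::collections::HashSet;
use std::sync::{Mutex, OnceLock};

static ANTIMODE_CACHE: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();

// Split the pipe-separated antimode string once, then reuse the cached
// set for every subsequent record instead of re-splitting per row.
fn get_cached_antimodes(antimode: &str) -> HashSet<String> {
    let cache = ANTIMODE_CACHE.get_or_init(|| Mutex::new(HashSet::new()));
    let mut cache = cache.lock().unwrap();
    if cache.is_empty() {
        cache.extend(antimode.split('|').map(String::from));
    }
    cache.clone()
}

fn main() {
    // "rare_a|rare_b" is an illustrative antimode string, not qsv output.
    let antimodes = get_cached_antimodes("rare_a|rare_b");
    assert!(antimodes.contains("rare_a"));
    // Later calls return the already-populated set without re-splitting.
    assert_eq!(get_cached_antimodes("ignored|now"), antimodes);
}

Note that the set is populated only on the first call; later calls return the cached values even when a different antimode string is passed, so the cache is per process rather than per column.

A similarly small sketch of the string-length test used by is_string_outlier and check_string_outlier: a value is flagged when its length sits more than three standard deviations from the column's mean length. The is_length_outlier name and the example stats (mean 5.0, stddev 1.0) are hypothetical.

// Sketch of the z-score length test: |(len - mean) / stddev| > 3.0
// flags a string whose length is unusually far from the column's mean.
fn is_length_outlier(value: &str, mean_len: f64, stddev_len: f64) -> bool {
    let len = value.len() as f64;
    ((len - mean_len) / stddev_len).abs() > 3.0
}

fn main() {
    // Mean length 5.0 and stddev 1.0 are made-up stats for illustration.
    assert!(!is_length_outlier("hello", 5.0, 1.0));        // z = 0
    assert!(is_length_outlier("hello, world!", 5.0, 1.0)); // z = 8
}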
