Merge pull request #2428 from dathere/pivotp-smarter-aggregation-suggestions

refactor: `pivotp` - smarter aggregation suggestions
jqnatividad authored Jan 10, 2025
2 parents 93aacb7 + 29d3d67 commit 241b50d
Showing 2 changed files with 146 additions and 30 deletions.
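
At a high level, the new suggestion logic looks at the value column's statistics (null count, coefficient of variation, skewness) together with the cardinality and sort order of the pivot and index columns. Below is a minimal, standalone sketch of the numeric decision path only; `Agg` and `NumericColStats` are hypothetical simplified stand-ins, not qsv's actual `PivotAgg` enum or stats records. The real implementation in the diff additionally handles Date/DateTime and String columns and reports each choice on stderr unless the quiet flag is set.

```rust
// Minimal, standalone sketch of the numeric-column heuristics in this commit.
// `Agg` and `NumericColStats` are hypothetical simplified stand-ins for qsv's
// actual `PivotAgg` enum and stats records.

#[derive(Debug, PartialEq)]
enum Agg {
    Count,
    Median,
    Mean,
    Sum,
}

struct NumericColStats {
    nullcount: u64,        // number of NULL values in the column
    cv: Option<f64>,       // coefficient of variation
    skewness: Option<f64>, // distribution skewness
}

#[allow(clippy::cast_precision_loss)]
fn suggest_numeric_agg(
    stats: &NumericColStats,
    row_count: u64,
    high_cardinality_pivot: bool,
    high_cardinality_index: bool,
    ordered_pivot: bool,
    ordered_index: bool,
) -> Agg {
    if stats.nullcount as f64 / row_count as f64 > 0.5 {
        // Mostly NULL: counting non-null values is more informative than summing.
        Agg::Count
    } else if stats.cv > Some(1.0) {
        // High coefficient of variation: median is a more robust central tendency.
        Agg::Median
    } else if high_cardinality_pivot && high_cardinality_index {
        // Both pivot and index columns are high cardinality: pick Mean when both
        // are ordered, otherwise fall back to Sum.
        if ordered_pivot && ordered_index {
            Agg::Mean
        } else {
            Agg::Sum
        }
    } else if matches!(stats.skewness, Some(s) if s.abs() > 2.0) {
        // Highly skewed distributions also favor the median.
        Agg::Median
    } else {
        Agg::Sum
    }
}

fn main() {
    let stats = NumericColStats { nullcount: 0, cv: Some(1.4), skewness: Some(0.3) };
    // CV > 1.0, so the heuristic picks Median regardless of the cardinality flags.
    let agg = suggest_numeric_agg(&stats, 100, false, false, false, false);
    assert_eq!(agg, Agg::Median);
    println!("suggested aggregation: {agg:?}");
}
```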
164 changes: 139 additions & 25 deletions src/cmd/pivotp.rs
@@ -201,7 +201,20 @@ fn validate_pivot_operation(metadata: &PivotMetadata) -> CliResult<()> {

/// Suggest an appropriate aggregation function based on column statistics
#[allow(clippy::cast_precision_loss)]
fn suggest_agg_function(args: &Args, value_cols: &[String]) -> CliResult<Option<PivotAgg>> {
fn suggest_agg_function(
args: &Args,
on_cols: &[String],
index_cols: Option<&[String]>,
value_cols: &[String],
) -> CliResult<Option<PivotAgg>> {
// If multiple value columns, default to First
if value_cols.len() > 1 {
return Ok(Some(PivotAgg::First));
}

let quiet = args.flag_quiet;

// Get stats for all columns with enhanced statistics
let schema_args = util::SchemaArgs {
flag_enum_threshold: 0,
flag_ignore_case: false,
@@ -218,18 +231,67 @@ fn suggest_agg_function(args: &Args, value_cols: &[String]) -> CliResult<Option<
flag_memcheck: false,
};

// If multiple value columns, default to First
if value_cols.len() > 1 {
return Ok(Some(PivotAgg::First));
}

let quiet = args.flag_quiet;

let (csv_fields, csv_stats) = STATS_RECORDS.get_or_init(|| {
get_stats_records(&schema_args, StatsMode::FrequencyForceStats)
.unwrap_or_else(|_| (ByteRecord::new(), Vec::new()))
});

let rconfig = Config::new(Some(&args.arg_input));
let row_count = util::count_rows(&rconfig)? as u64;
// eprintln!("row_count: {}\nstats: {:#?}", row_count, csv_stats);

// Analyze pivot column characteristics
let mut high_cardinality_pivot = false;
let mut ordered_pivot = false; // Track if pivot columns are ordered
for on_col in on_cols {
if let Some(pos) = csv_fields
.iter()
.position(|f| std::str::from_utf8(f).unwrap_or("") == on_col)
{
let stats = &csv_stats[pos];

// Check cardinality ratio
if stats.cardinality as f64 / row_count as f64 > 0.5 {
high_cardinality_pivot = true;
if !quiet {
eprintln!("Info: Pivot column \"{on_col}\" has high cardinality");
}
}

// Check if column is unordered based on sort_order
if let Some(sort_order) = &stats.sort_order {
ordered_pivot = sort_order != "Unsorted";
}
}
}

// Analyze index column characteristics
let mut high_cardinality_index = false;
let mut ordered_index = false;
if let Some(idx_cols) = index_cols {
for idx_col in idx_cols {
if let Some(pos) = csv_fields
.iter()
.position(|f| std::str::from_utf8(f).unwrap_or("") == idx_col)
{
let stats = &csv_stats[pos];

// Check cardinality ratio
if stats.cardinality as f64 / row_count as f64 > 0.5 {
high_cardinality_index = true;
if !quiet {
eprintln!("Info: Index column \"{idx_col}\" has high cardinality");
}
}

// Check if column is unordered
if let Some(sort_order) = &stats.sort_order {
ordered_index = sort_order != "Unsorted";
}
}
}
}

// Get stats for the value column
let value_col = &value_cols[0];
let field_pos = csv_fields
@@ -238,8 +300,6 @@ fn suggest_agg_function(args: &Args, value_cols: &[String]) -> CliResult<Option<

if let Some(pos) = field_pos {
let stats = &csv_stats[pos];
let rconfig = Config::new(Some(&args.arg_input));
let row_count = util::count_rows(&rconfig)? as u64;

// Suggest aggregation based on field type and statistics
let suggested_agg = match stats.r#type.as_str() {
@@ -255,25 +315,69 @@ fn suggest_agg_function(args: &Args, value_cols: &[String]) -> CliResult<Option<
eprintln!("Info: \"{value_col}\" contains >50% NULL values, using Count");
}
PivotAgg::Count
} else if stats.cv > Some(1.0) {
// High coefficient of variation suggests using median for better central
// tendency
if !quiet {
eprintln!(
"Info: High variability in values (CV > 1), using Median for more \
robust central tendency"
);
}
PivotAgg::Median
} else if high_cardinality_pivot && high_cardinality_index {
if ordered_pivot && ordered_index {
// With ordered high cardinality columns, mean might be more meaningful
if !quiet {
eprintln!(
"Info: Ordered high cardinality columns detected, using Mean"
);
}
PivotAgg::Mean
} else {
// With unordered high cardinality, sum might be more appropriate
if !quiet {
eprintln!(
"Info: High cardinality in pivot and index columns, using Sum"
);
}
PivotAgg::Sum
}
} else if let Some(skewness) = stats.skewness {
if skewness.abs() > 2.0 {
// Highly skewed data might benefit from median
if !quiet {
eprintln!("Info: Highly skewed numeric data detected, using Median");
}
PivotAgg::Median
} else {
PivotAgg::Sum
}
} else {
PivotAgg::Sum
}
},
"Date" | "DateTime" => {
if stats.cardinality as f64 / row_count as f64 > 0.9 {
if !quiet {
eprintln!(
"Info: {} column \"{value_col}\" has high cardinality, using First",
stats.r#type
);
if high_cardinality_pivot || high_cardinality_index {
if ordered_pivot && ordered_index {
if !quiet {
eprintln!(
"Info: Ordered temporal data with high cardinality, using Last"
);
}
PivotAgg::Last
} else {
if !quiet {
eprintln!(
"Info: High cardinality detected, using First for {} column",
stats.r#type
);
}
PivotAgg::First
}
PivotAgg::First
} else {
if !quiet {
eprintln!(
"Info: \"{value_col}\" is a {} column, using Count",
stats.r#type
);
eprintln!("Info: Using Count for {} column", stats.r#type);
}
PivotAgg::Count
}
@@ -284,14 +388,19 @@ fn suggest_agg_function(args: &Args, value_cols: &[String]) -> CliResult<Option<
eprintln!("Info: \"{value_col}\" contains all unique values, using First");
}
PivotAgg::First
} else if stats.cardinality as f64 / row_count as f64 > 0.5 {
} else if stats.sparsity > Some(0.5) {
if !quiet {
eprintln!("Info: Sparse data detected, using Count");
}
PivotAgg::Count
} else if high_cardinality_pivot || high_cardinality_index {
if !quiet {
eprintln!("Info: \"{value_col}\" has high cardinality, using Count");
eprintln!("Info: High cardinality detected, using Count");
}
PivotAgg::Count
} else {
if !quiet {
eprintln!("Info: \"{value_col}\" is a String column, using Count");
eprintln!("Info: Using Count for String column");
}
PivotAgg::Count
}
@@ -363,7 +472,12 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
"smart" => {
if let Some(value_cols) = &value_cols {
// Try to suggest an appropriate aggregation function
if let Some(suggested_agg) = suggest_agg_function(&args, value_cols)? {
if let Some(suggested_agg) = suggest_agg_function(
&args,
&on_cols,
index_cols.as_deref(),
value_cols,
)? {
suggested_agg
} else {
// fallback to first, which always works
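
For reference, the high-cardinality and orderedness flags used in the function above are derived per pivot/index column from the precomputed stats: a column counts as high cardinality when its distinct-value count exceeds 50% of the row count, and as ordered when its `sort_order` statistic is anything other than "Unsorted". A standalone sketch, with a hypothetical simplified `ColStats` type rather than qsv's actual stats record:

```rust
// Standalone sketch of how the pivot/index column flags are derived.
// `ColStats` is a hypothetical simplified stand-in for qsv's stats record.

struct ColStats {
    cardinality: u64,           // number of distinct values
    sort_order: Option<String>, // e.g. "Ascending", "Descending", "Unsorted"
}

#[allow(clippy::cast_precision_loss)]
fn analyze_columns(cols: &[ColStats], row_count: u64) -> (bool, bool) {
    let mut high_cardinality = false;
    let mut ordered = false;
    for stats in cols {
        // High cardinality: distinct values exceed 50% of the row count.
        if stats.cardinality as f64 / row_count as f64 > 0.5 {
            high_cardinality = true;
        }
        // Ordered: the precomputed sort_order statistic is anything but "Unsorted".
        // As in the committed loop, a later column overwrites an earlier one here.
        if let Some(sort_order) = &stats.sort_order {
            ordered = sort_order != "Unsorted";
        }
    }
    (high_cardinality, ordered)
}

fn main() {
    let cols = vec![
        ColStats { cardinality: 90, sort_order: Some("Ascending".to_string()) },
        ColStats { cardinality: 2, sort_order: Some("Unsorted".to_string()) },
    ];
    // 90/100 > 0.5 flags high cardinality; the last column's "Unsorted" wins.
    let (high_card, ordered) = analyze_columns(&cols, 100);
    println!("high cardinality: {high_card}, ordered: {ordered}");
}
```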
12 changes: 7 additions & 5 deletions tests/test_pivotp.rs
@@ -571,14 +571,16 @@ pivotp_test!(
wrk.assert_success(&mut cmd);

let msg = wrk.output_stderr(&mut cmd);
let expected_msg = "Pivot on-column cardinality:\n product: 2\n(2, 3)\n";
let expected_msg = "Info: High variability in values (CV > 1), using Median for more \
robust central tendency\nPivot on-column cardinality:\n product: \
2\n(2, 3)\n";
assert_eq!(msg, expected_msg);

let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![
svec!["date", "A", "B"],
svec!["2023-01-01", "300", "150"],
svec!["2023-01-02", "300", "600"],
svec!["2023-01-01", "150.0", "150.0"],
svec!["2023-01-02", "300.0", "300.0"],
];
assert_eq!(got, expected);
}
@@ -604,8 +606,8 @@ pivotp_test!(
let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![
svec!["date", "A", "B"],
svec!["2023-01-01", "300", "150"],
svec!["2023-01-02", "300", "600"],
svec!["2023-01-01", "150.0", "150.0"],
svec!["2023-01-02", "300.0", "300.0"],
];
assert_eq!(got, expected);
}
