From 76305f2a22d481578dd3fa53f07c6601fd6b6806 Mon Sep 17 00:00:00 2001 From: Andrew Frantz Date: Fri, 15 Dec 2023 10:17:56 -0500 Subject: [PATCH] refactor(derive/readlen): move num_samples counting to outer func --- src/derive/command/readlen.rs | 11 +++++------ src/derive/readlen/compute.rs | 15 +++++++-------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/derive/command/readlen.rs b/src/derive/command/readlen.rs index 9449b03..6d1c944 100644 --- a/src/derive/command/readlen.rs +++ b/src/derive/command/readlen.rs @@ -70,16 +70,15 @@ pub fn derive(args: DeriveReadlenArgs) -> anyhow::Result<()> { .and_modify(|e| *e += 1) .or_insert(1 as u64); - if sample_max > 0 { - samples += 1; - if samples > sample_max { - break; - } + samples += 1; + if sample_max > 0 && samples > sample_max { + break; } } // (2) Derive the consensus read length based on the read lengths gathered. - let result = compute::predict(read_lengths, args.majority_vote_cutoff.unwrap()).unwrap(); + let result = + compute::predict(read_lengths, samples, args.majority_vote_cutoff.unwrap()).unwrap(); // (3) Print the output to stdout as JSON (more support for different output // types may be added in the future, but for now, only JSON). diff --git a/src/derive/readlen/compute.rs b/src/derive/readlen/compute.rs index 679e8d2..72c1c59 100644 --- a/src/derive/readlen/compute.rs +++ b/src/derive/readlen/compute.rs @@ -44,26 +44,25 @@ impl DerivedReadlenResult { /// resulting [`DerivedReadlenResult`] should be evaluated accordingly. pub fn predict( read_lengths: HashMap, + num_samples: u64, majority_vote_cutoff: f64, ) -> Result { - let mut num_records = 0; let mut max_count = 0; let mut max_read_length = 0; for (read_length, count) in &read_lengths { - num_records += *count; if *read_length > max_read_length { max_read_length = *read_length; max_count = *count; } } - if num_records == 0 { + if num_samples <= 0 { bail!("No read lengths were detected in the file."); } let consensus_read_length = max_read_length; - let majority_detected = max_count as f64 / num_records as f64; + let majority_detected = max_count as f64 / num_samples as f64; // Sort the read lengths by their key for output. let mut read_lengths: Vec<(u32, u64)> = read_lengths.into_iter().collect(); @@ -87,14 +86,14 @@ mod tests { #[test] fn test_derive_readlen_from_empty_hashmap() { let read_lengths = HashMap::new(); - let result = predict(read_lengths, 0.7); + let result = predict(read_lengths, 0, 0.7); assert!(result.is_err()); } #[test] fn test_derive_readlen_when_all_readlengths_equal() { let read_lengths = HashMap::from([(100, 10)]); - let result = predict(read_lengths, 1.0).unwrap(); + let result = predict(read_lengths, 10, 1.0).unwrap(); assert!(result.succeeded); assert_eq!(result.consensus_read_length, Some(100)); assert_eq!(result.majority_pct_detected, 100.0); @@ -104,7 +103,7 @@ mod tests { #[test] fn test_derive_readlen_success_when_not_all_readlengths_equal() { let read_lengths = HashMap::from([(101, 1000), (100, 5), (99, 5)]); - let result = predict(read_lengths, 0.7).unwrap(); + let result = predict(read_lengths, 1010, 0.7).unwrap(); assert!(result.succeeded); assert_eq!(result.consensus_read_length, Some(101)); assert!(result.majority_pct_detected > 99.0); @@ -114,7 +113,7 @@ mod tests { #[test] fn test_derive_readlen_fail_when_not_all_readlengths_equal() { let read_lengths = HashMap::from([(101, 5), (100, 1000), (99, 5)]); - let result = predict(read_lengths, 0.7).unwrap(); + let result = predict(read_lengths, 1010, 0.7).unwrap(); assert!(!result.succeeded); assert_eq!(result.consensus_read_length, None); assert!(result.majority_pct_detected < 0.7);