Skip to content

Commit

Permalink
refactor(derive/readlen): move num_samples counting to outer func
Browse files Browse the repository at this point in the history
  • Loading branch information
a-frantz committed Dec 15, 2023
1 parent 65ea501 commit 76305f2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
11 changes: 5 additions & 6 deletions src/derive/command/readlen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,15 @@ pub fn derive(args: DeriveReadlenArgs) -> anyhow::Result<()> {
.and_modify(|e| *e += 1)
.or_insert(1 as u64);

if sample_max > 0 {
samples += 1;
if samples > sample_max {
break;
}
samples += 1;
if sample_max > 0 && samples > sample_max {
break;
}
}

// (2) Derive the consensus read length based on the read lengths gathered.
let result = compute::predict(read_lengths, args.majority_vote_cutoff.unwrap()).unwrap();
let result =
compute::predict(read_lengths, samples, args.majority_vote_cutoff.unwrap()).unwrap();

// (3) Print the output to stdout as JSON (more support for different output
// types may be added in the future, but for now, only JSON).
Expand Down
15 changes: 7 additions & 8 deletions src/derive/readlen/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,25 @@ impl DerivedReadlenResult {
/// resulting [`DerivedReadlenResult`] should be evaluated accordingly.
pub fn predict(
read_lengths: HashMap<u32, u64>,
num_samples: u64,
majority_vote_cutoff: f64,
) -> Result<DerivedReadlenResult, anyhow::Error> {
let mut num_records = 0;
let mut max_count = 0;
let mut max_read_length = 0;

for (read_length, count) in &read_lengths {
num_records += *count;
if *read_length > max_read_length {
max_read_length = *read_length;
max_count = *count;
}
}

if num_records == 0 {
if num_samples <= 0 {
bail!("No read lengths were detected in the file.");
}

let consensus_read_length = max_read_length;
let majority_detected = max_count as f64 / num_records as f64;
let majority_detected = max_count as f64 / num_samples as f64;

// Sort the read lengths by their key for output.
let mut read_lengths: Vec<(u32, u64)> = read_lengths.into_iter().collect();
Expand All @@ -87,14 +86,14 @@ mod tests {
// With no observed read lengths and a sample count of zero, `predict` has
// nothing to build a consensus from and must return an error, not a result.
#[test]
fn test_derive_readlen_from_empty_hashmap() {
    let read_lengths = HashMap::new();
    // `predict` takes the map by value, so it is passed exactly once, using
    // the three-argument form (read lengths, number of samples, cutoff).
    let result = predict(read_lengths, 0, 0.7);
    assert!(result.is_err());
}

#[test]
fn test_derive_readlen_when_all_readlengths_equal() {
let read_lengths = HashMap::from([(100, 10)]);
let result = predict(read_lengths, 1.0).unwrap();
let result = predict(read_lengths, 10, 1.0).unwrap();
assert!(result.succeeded);
assert_eq!(result.consensus_read_length, Some(100));
assert_eq!(result.majority_pct_detected, 100.0);
Expand All @@ -104,7 +103,7 @@ mod tests {
#[test]
fn test_derive_readlen_success_when_not_all_readlengths_equal() {
let read_lengths = HashMap::from([(101, 1000), (100, 5), (99, 5)]);
let result = predict(read_lengths, 0.7).unwrap();
let result = predict(read_lengths, 1010, 0.7).unwrap();
assert!(result.succeeded);
assert_eq!(result.consensus_read_length, Some(101));
assert!(result.majority_pct_detected > 99.0);
Expand All @@ -114,7 +113,7 @@ mod tests {
#[test]
fn test_derive_readlen_fail_when_not_all_readlengths_equal() {
let read_lengths = HashMap::from([(101, 5), (100, 1000), (99, 5)]);
let result = predict(read_lengths, 0.7).unwrap();
let result = predict(read_lengths, 1010, 0.7).unwrap();
assert!(!result.succeeded);
assert_eq!(result.consensus_read_length, None);
assert!(result.majority_pct_detected < 0.7);
Expand Down

0 comments on commit 76305f2

Please sign in to comment.