Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding additional inner-loop-break-condition when the last 3 consecut… #15

Merged
merged 1 commit into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions man/aldknni.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 15 additions & 5 deletions src/rust/src/aldknni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,13 @@ impl GenotypesAndPhenotypes {
n_non_missing
};
// Select randomly non-missing pools to impute and estimate imputation error from
let mut rng = rand::thread_rng();
let mut rng: rand::rngs::ThreadRng = rand::thread_rng();
let idx_random_pools: Vec<usize> = (0..vec_q.len()).filter(|&idx| !vec_q[idx].is_nan()).choose_multiple(&mut rng, n_reps);
let mut optimum_mae = 1.0;
let mut optimum_min_loci_corr = vec_min_loci_corr[0];
let mut optimum_max_pool_dist = vec_max_pool_dist[0];
// Optimum mae, and parameters
let mut optimum_mae: f64 = 1.0;
let mut optimum_min_loci_corr: f64 = vec_min_loci_corr[0];
let mut optimum_max_pool_dist: f64 = vec_max_pool_dist[0];
let mut recent_3_maes: Vec<f64> = vec![0.0, 0.5, 1.0];
// Find the optimal min_loci_corr and max_pool_dist which minimise imputation error (MAE: mean absolute error)
// Across minimum loci correlation thresholds
for min_loci_corr in vec_min_loci_corr.iter() {
Expand Down Expand Up @@ -347,7 +349,15 @@ impl GenotypesAndPhenotypes {
).abs();
}
mae /= n_reps as f64;
if (mae <= f64::EPSILON) | (mae > optimum_mae) {
// The most recent 3 MAEs
recent_3_maes[2] = recent_3_maes[1];
recent_3_maes[1] = recent_3_maes[0];
recent_3_maes[0] = mae;
// Inner loop break conditions
if (mae <= f64::EPSILON) |
(mae > optimum_mae) |
(((recent_3_maes[0]-recent_3_maes[1]).abs() <= f64::EPSILON) &
((recent_3_maes[0]-recent_3_maes[2]).abs() <= f64::EPSILON)) {
break;
}
if mae < optimum_mae {
Expand Down
6 changes: 3 additions & 3 deletions src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn impute(
// 2) sync - intermediate richness and most preferred (*.sync)
// 3) geno - least detailed - tab-delimited: chr,pos,allele,sample-1,sample-2,some-name-@#@#$%^&*(+)}:<'?"-with-a-bunch-of-asci-characters,... (*.txt)
let extension_name: &str = fname
.split(".")
.split('.')
.collect::<Vec<&str>>()
.last()
.expect("Error extracting the last character of the input filename in impute().");
Expand Down Expand Up @@ -216,7 +216,7 @@ fn impute(
vec_header
};
let vec_header: Vec<&str> = if vec_header.len() == 1 {
header.split(";").collect()
header.split(';').collect()
} else {
vec_header
};
Expand Down Expand Up @@ -297,7 +297,7 @@ fn impute(
} else {
fname_out_prefix.to_owned() + "-" + &rand_id + "-IMPUTED.csv"
};
let _ = if &imputation_method == &"mean".to_owned() {
let _ = if imputation_method == *"mean" {
println!("###################################################################################################");
println!("mvi: mean value imputation");
println!("###################################################################################################");
Expand Down
4 changes: 2 additions & 2 deletions src/rust/src/vcf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ impl ChunkyReadAnalyseWrite<VcfLine, fn(&mut VcfLine, &FilterStats) -> Option<St
// Find the positions where to split the file into n_threads pieces
let chunks = find_file_splits(&fname, n_threads).expect("Error calling find_file_splits() within the read_analyse_write() method for FileVcf struct.");
let n_threads = chunks.len() - 1;
let outname_ndigits = chunks[n_threads].to_string().len();
let n_digits = chunks[n_threads].to_string().len();
println!("Chunks: {:?}", chunks);
// Tuple arguments of vcf2sync_chunks
// Instantiate thread object for parallel execution
Expand All @@ -428,7 +428,7 @@ impl ChunkyReadAnalyseWrite<VcfLine, fn(&mut VcfLine, &FilterStats) -> Option<St
let self_clone = self.clone();
let start = chunks[i];
let end = chunks[i + 1];
let outname_ndigits = outname_ndigits;
let outname_ndigits = n_digits;
let filter_stats = filter_stats.clone();
let thread_ouputs_clone = thread_ouputs.clone(); // Mutated within the current thread worker
let thread = std::thread::spawn(move || {
Expand Down
20 changes: 10 additions & 10 deletions tests/tests.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@ test_that(
)

test_that(
"aldknni_optim", {
print("aldknni_optim:")
vcf = fn_extract_missing(aldknni(fname=fname_vcf))
sync = fn_extract_missing(aldknni(fname=fname_sync))
csv = fn_extract_missing(aldknni(fname=fname_csv)) - c(0, 0, 0, 0, 1, 0)
"aldknni_fixed", {
print("aldknni_fixed:")
vcf = fn_extract_missing(aldknni(fname=fname_vcf, min_loci_corr=0.9, max_pool_dist=0.1))
sync = fn_extract_missing(aldknni(fname=fname_sync, min_loci_corr=0.9, max_pool_dist=0.1))
csv = fn_extract_missing(aldknni(fname=fname_csv, min_loci_corr=0.9, max_pool_dist=0.1)) - c(0, 0, 0, 0, 1, 0)
expect_equal(vcf, sync, tolerance=0.1)
expect_equal(vcf, csv, tolerance=0.1)
}
)

test_that(
"aldknni_fixed", {
print("aldknni_fixed:")
vcf = fn_extract_missing(aldknni(fname=fname_vcf, min_loci_corr=0.9, max_pool_dist=0.1))
sync = fn_extract_missing(aldknni(fname=fname_sync, min_loci_corr=0.9, max_pool_dist=0.1))
csv = fn_extract_missing(aldknni(fname=fname_csv, min_loci_corr=0.9, max_pool_dist=0.1)) - c(0, 0, 0, 0, 1, 0)
"aldknni_optim", {
print("aldknni_optim:")
vcf = fn_extract_missing(aldknni(fname=fname_vcf))
sync = fn_extract_missing(aldknni(fname=fname_sync))
csv = fn_extract_missing(aldknni(fname=fname_csv)) - c(0, 0, 0, 0, 1, 0)
expect_equal(vcf, sync, tolerance=0.1)
expect_equal(vcf, csv, tolerance=0.1)
}
Expand Down
Loading