Skip to content

Commit

Permalink
Huge changes as I tried to make tree moves use an iterator, broke eve…
Browse files Browse the repository at this point in the history
…rything, fixed it again finally
  • Loading branch information
jhellewell14 committed Dec 16, 2024
1 parent f1fdbdf commit 702ad73
Show file tree
Hide file tree
Showing 10 changed files with 946 additions and 878 deletions.
69 changes: 34 additions & 35 deletions src/branchlength.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,39 @@
use crate::{rate_matrix::RateMatrix, treestate::TreeMove};
use rand::prelude::Distribution;
// use rand::distributions::Normal;
use statrs::distribution::Normal;
use crate::topology::Topology;
use crate::TreeState;
pub struct BranchMove{
indices: Vec<usize>,
}
// use crate::{rate_matrix::RateMatrix, treestate::TreeMove};
// use rand::prelude::Distribution;
// // use rand::distributions::Normal;
// use statrs::distribution::Normal;
// use crate::topology::Topology;
// use crate::TreeState;
// pub struct BranchMove{
// indices: Vec<usize>,
// }

impl<R: RateMatrix> TreeMove<R> for BranchMove {
fn generate(&self, ts: &crate::treestate::TreeState<R>) -> crate::treestate::TreeState<R> {
let normal = Normal::new(0.0, 1.0).unwrap();
let mut changes: Vec<usize> = Vec::new();
// impl<R: RateMatrix> TreeMove<R> for BranchMove {
// fn generate(&self, ts: &crate::treestate::TreeState<R>) -> crate::treestate::TreeState<R> {
// let normal = Normal::new(0.0, 1.0).unwrap();
// let mut changes: Vec<usize> = Vec::new();

// This is not ideal
let mut nodes = ts.top.nodes.clone();
// // This is not ideal
// let mut nodes = ts.top.nodes.clone();

for i in self.indices.iter() {
let ind = *i;
let mut bl = nodes[ind].get_branchlen();
bl = bl.ln() + normal.sample(&mut rand::thread_rng());
nodes[ind].set_branchlen(bl.exp());
changes.push(ind);
}
// for i in self.indices.iter() {
// let ind = *i;
// let mut bl = nodes[ind].get_branchlen();
// bl = bl.ln() + normal.sample(&mut rand::thread_rng());
// nodes[ind].set_branchlen(bl.exp());
// changes.push(ind);
// }

let new_top = Topology{
nodes: nodes,
tree_vec: ts.top.tree_vec.clone(),
likelihood: ts.top.likelihood,
};
// let new_top = Topology{
// nodes: nodes,
// tree_vec: ts.top.tree_vec.clone(),
// };

TreeState{
top: new_top,
mat: ts.mat,
ll: ts.ll,
changed_nodes: Some(changes),
}
}
}
// TreeState{
// top: new_top,
// mat: ts.mat,
// ll: ts.ll,
// changed_nodes: Some(changes),
// }
// }
// }
120 changes: 65 additions & 55 deletions src/genetic_data.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::os::unix::thread;
use std::thread::current;
use std::collections::HashMap;
use needletail::parse_fastx_file;
use crate::topology::Topology;
use ndarray::s;
use logaddexp::LogAddExp;
use ndarray::s;
use needletail::parse_fastx_file;
use rand::{thread_rng, Rng};
use std::collections::HashMap;
use std::os::unix::thread;
use std::thread::current;

const NEGINF: f64 = -f64::INFINITY;
// (A, C, G, T)
Expand Down Expand Up @@ -51,25 +51,28 @@ pub fn count_sequences(filename: &str) -> usize {
let mut n_seqs: usize = 0;
while let Some(_record) = reader.next() {
n_seqs += 1;
};
}
n_seqs
}


pub fn create_genetic_data(filename: &str, topology: &Topology, rate_matrix: &na::Matrix4<f64>) -> ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> {
pub fn create_genetic_data(
filename: &str,
topology: &Topology,
rate_matrix: &na::Matrix4<f64>,
) -> ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> {
// Count number of sequences and their length
let mut n_seqs = 0;
let mut n_bases= 0;
let mut n_bases = 0;
let mut reader = parse_fastx_file(filename).expect("Error parsing file");
while let Some(record) = reader.next() {
let seqrec = record.expect("Invalid record");
n_seqs += 1;
n_bases = seqrec.num_bases();
}
// Create pre-filled array
let mut gen_data: ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> =
ndarray::Array3::from_elem(((2 * n_seqs) - 1, n_bases, 4), -99.0);
// println!("Assigning data for {} leaves and {} total nodes", n_seqs, (2 * n_seqs) + 1);
let mut gen_data: ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> =
ndarray::Array3::from_elem((2 * n_seqs - 1, n_bases, 4), -99.0);
// println!("Assigning data for {} leaves and {} total nodes", n_seqs, (2 * n_seqs) + 1);

let mut reader2 = parse_fastx_file(filename).expect("Error parsing file");
let mut seq_i = 0;
Expand All @@ -80,17 +83,18 @@ pub fn create_genetic_data(filename: &str, topology: &Topology, rate_matrix: &na
for j in 0..4 {
gen_data[[seq_i, loc_i, j]] = *cur.get(j).unwrap();
}
}
}
seq_i += 1;
}

create_internal_data(gen_data, topology, rate_matrix)
}


pub fn create_internal_data(mut data: ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>,
topology: &Topology, rate_matrix: &na::Matrix4<f64>) ->
ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> {
pub fn create_internal_data(
mut data: ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>,
topology: &Topology,
rate_matrix: &na::Matrix4<f64>,
) -> ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> {
// Iterate over internal nodes postorder
let nodes = topology.postorder_notips(topology.get_root());

Expand All @@ -99,10 +103,12 @@ pub fn create_internal_data(mut data: ndarray::ArrayBase<ndarray::OwnedRepr<f64>
// Calculate node likelihood
let lchild = node.get_lchild().unwrap();
let rchild = node.get_rchild().unwrap();
let node_ll = node_likelihood(slice_data(lchild, &data),
slice_data(rchild, &data),
&matrix_exp(rate_matrix, topology.nodes[lchild].get_branchlen()),
&matrix_exp(rate_matrix, topology.nodes[lchild].get_branchlen()));
let node_ll = node_likelihood(
data.slice(s![lchild, .., ..]),
data.slice(s![rchild, .., ..]),
&matrix_exp(rate_matrix, topology.nodes[lchild].get_branchlen()),
&matrix_exp(rate_matrix, topology.nodes[lchild].get_branchlen()),
);
// let node_ll = node_likelihood(node.get_lchild().unwrap(), node.get_rchild().unwrap(), &gen_data, topology, rate_matrix);
// let node_ll = node_likelihood(i, &gen_data, topology, rate_matrix);
// Add to genetic data array
Expand All @@ -112,13 +118,16 @@ pub fn create_internal_data(mut data: ndarray::ArrayBase<ndarray::OwnedRepr<f64>
data
}

pub fn create_dummy_gendata(n_bases: usize, topology: &Topology, rate_matrix: &na::Matrix4<f64>) -> ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> {

pub fn create_dummy_gendata(
n_bases: usize,
topology: &Topology,
rate_matrix: &na::Matrix4<f64>,
) -> ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> {
let n_seqs = topology.count_leaves();

let mut gen_data: ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> =
let mut gen_data: ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>> =
ndarray::Array3::from_elem(((2 * n_seqs) + 1, n_bases, 4), 0.0);

let mut rng = thread_rng();

for i in 0..n_seqs {
Expand All @@ -130,42 +139,44 @@ pub fn create_dummy_gendata(n_bases: usize, topology: &Topology, rate_matrix: &n

create_internal_data(gen_data, topology, rate_matrix)
}

pub fn child_likelihood_i(i: usize, ll: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 1]>>, p: &na::Matrix4<f64>) -> f64 {
p
.row(i)
.iter()
.zip(ll.iter())
.map(|(a, b)| a.ln() + *b)
.reduce(|a, b| a.ln_add_exp(b))
.unwrap()

pub fn child_likelihood_i(
i: usize,
ll: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 1]>>,
p: &na::Matrix4<f64>,
) -> f64 {
p.row(i)
.iter()
.zip(ll.iter())
.map(|(a, b)| a.ln() + *b)
.reduce(|a, b| a.ln_add_exp(b))
.unwrap()
}

pub fn matrix_exp(rate_matrix: &na::Matrix4<f64>, branch_len: f64) -> na::Matrix4<f64> {
na::Matrix::exp(&(rate_matrix * branch_len))
}

pub fn slice_data(index: usize, data: &ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>) ->
ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>> {
data.slice(s![index, .., ..])
}

pub fn node_likelihood(seql: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>,
pub fn node_likelihood(
seql: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>,
seqr: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>,
matrixl: &na::Matrix4<f64>,
matrixr: &na::Matrix4<f64>,
) -> ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>> {

let out = ndarray::Array2::from_shape_fn((seql.dim().0, 4), |(i, j)|
child_likelihood_i(j, seql.slice(s![i, ..]), matrixl) +
child_likelihood_i(j, seqr.slice(s![i, ..]), matrixr));
) -> ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>> {
let out = ndarray::Array2::from_shape_fn((seql.dim().0, 4), |(i, j)| {
child_likelihood_i(j, seql.slice(s![i, ..]), matrixl)
+ child_likelihood_i(j, seqr.slice(s![i, ..]), matrixr)
});

out
out
}

pub const BF_DEFAULT: [f64; 4] = [0.25, 0.25, 0.25, 0.25];

pub fn base_freq_logse(muta: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 1]>>, bf: [f64; 4]) -> f64 {
pub fn base_freq_logse(
muta: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 1]>>,
bf: [f64; 4],
) -> f64 {
muta.iter()
.zip(bf.iter())
.fold(0.0, |tot, (muta, bf)| tot + muta.exp() * bf)
Expand All @@ -174,18 +185,17 @@ pub fn base_freq_logse(muta: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray

impl Topology {
pub fn find_changes(&self, other: &Topology) -> Option<Vec<usize>> {
let out: Vec<usize> = self.nodes
.iter()
.zip(other.nodes.iter())
.filter(|(a, b)| a.get_parent().ne(&b.get_parent()))
.map(|(a, b)| a.get_id())
.collect();
let out: Vec<usize> = self
.nodes
.iter()
.zip(other.nodes.iter())
.filter(|(a, b)| a.get_parent().ne(&b.get_parent()))
.map(|(a, _)| a.get_id())
.collect();
if out.is_empty() {
None
} else {
Some(out)
}
}

}

Loading

0 comments on commit 702ad73

Please sign in to comment.