Skip to content

Commit

Permalink
Integrating new TreeState struct and TreeMove trait into code
Browse files Browse the repository at this point in the history
  • Loading branch information
jhellewell14 committed Nov 29, 2024
1 parent 49e0bb4 commit d1edf0b
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 170 deletions.
178 changes: 19 additions & 159 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ mod genetic_data;
mod moves;
mod state_data;
mod branchlength;
mod treestate;

use rate_matrix::RateMatrix;
use state_data::create_dummy_statedata;
use topology::Topology;
use treestate::TreeState;

use crate::newick_to_vec::*;
extern crate nalgebra as na;
Expand All @@ -20,11 +22,6 @@ use std::env::args;
use std::time::Instant;
use crate::genetic_data::*;
use crate::moves::*;
use rand::Rng;
use crate::iterators::Handedness;
// use crate::rate_matrix::update_matrix;
use ndarray::s;
use std::collections::HashMap;

pub fn main() {
let args = cli_args();
Expand All @@ -38,161 +35,23 @@ pub fn main() {

let mut t: Topology = Topology::from_vec(&v);

let p = &rate_matrix::GTR::default();
let p = rate_matrix::GTR::default();
let mut gen_data = create_genetic_data(&args.alignment, &t, &p.get_matrix());

println!("{:?}", likelihood(&t, &gen_data));
println!("{:?}", t.get_newick());
println!("{:?}", t.tree_vec);

let mge_mat = na::Matrix2::new(0.4, 0.6, 0.6, 0.4);
let mut st = create_dummy_statedata(1, &t, &mge_mat);

let nodes: Vec<usize> = t.postorder(t.get_root()).map(|n| n.get_id()).collect();
for i in nodes {
let old_len = t.nodes[i].get_branchlen();
t.nodes[i].set_branchlen(old_len + 1.0);
let mut ts = TreeState{
top: t,
mat: p,
ll: None,
changed_nodes: None,
};

/// A tree topology bundled with the rate matrix used to evaluate it.
pub struct TreeState<R>
where
    R: RateMatrix,
{
    /// The tree itself.
    top: Topology,
    /// Rate matrix associated with this state.
    mat: R,
    /// Cached likelihood value; `None` until one has been computed.
    ll: Option<f64>,
    /// Ids of the nodes reported as changed by the move that produced this
    /// state; `None` when no changes were reported.
    changed_nodes: Option<Vec<usize>>,
}

/// A proposal that derives a candidate [`TreeState`] from an existing one.
pub trait TreeMove<R>
where
    R: RateMatrix,
{
    /// Builds a new state from `ts`; `ts` is only borrowed and is left untouched.
    fn generate(&self, ts: &TreeState<R>) -> TreeState<R>;
}

/// Move that proposes a new rate matrix while keeping the tree as it is.
pub struct MatrixMove {}

impl<R> TreeMove<R> for MatrixMove
where
    R: RateMatrix,
{
    /// Produces a state whose matrix comes from `ts.mat.matrix_move()` and whose
    /// topology is a copy of the current one.
    ///
    /// Every node yielded by `postorder_notips` from the root is reported in
    /// `changed_nodes` (presumably all internal nodes -- confirm against
    /// `Topology::postorder_notips`). The cached `ll` is carried over as-is.
    fn generate(&self, ts: &TreeState<R>) -> TreeState<R> {
        let proposed_matrix = ts.mat.matrix_move();

        let root = ts.top.get_root();
        let affected_ids: Vec<usize> = ts
            .top
            .postorder_notips(root)
            .map(|node| node.get_id())
            .collect();

        TreeState {
            // The topology is rebuilt field by field with clones; the original
            // author flagged this copy as "not ideal".
            top: Topology {
                nodes: ts.top.nodes.clone(),
                tree_vec: ts.top.tree_vec.clone(),
                likelihood: ts.top.likelihood,
            },
            mat: proposed_matrix,
            ll: ts.ll,
            changed_nodes: Some(affected_ids),
        }
    }
}

impl<R> TreeMove<R> for ExactMove
where
    R: RateMatrix,
{
    /// Produces a state whose topology is built from `self.target_vector`.
    ///
    /// `changed_nodes` is whatever `find_changes` reports between the current
    /// topology and the new one. The matrix and cached `ll` are taken from `ts`.
    fn generate(&self, ts: &TreeState<R>) -> TreeState<R> {
        let top = Topology::from_vec(&self.target_vector);
        let changed_nodes = ts.top.find_changes(&top);

        TreeState {
            top,
            // NOTE(review): `ts` is a shared reference, so moving `ts.mat` out of
            // it only compiles when `R` is `Copy` -- confirm the bounds on
            // `RateMatrix`, or switch to an explicit clone.
            mat: ts.mat,
            ll: ts.ll,
            changed_nodes,
        }
    }
}

impl<R: RateMatrix> TreeState<R> {
println!("{:?}", ts.likelihood(&gen_data));
println!("{:?}", ts.top.get_newick());
println!("{:?}", ts.top.tree_vec);

/// Sums `base_freq_logse(row, BF_DEFAULT)` over every row of the root node's
/// 2-D slice of `gen_data`.
///
/// The slice is selected by using the root node's id as the index along the
/// first axis of `gen_data`.
pub fn likelihood(&self, gen_data: &ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>) -> f64 {
    let root_id = self.top.get_root().get_id();
    let root_rows = gen_data.slice(s![root_id, .., ..]);

    let mut total = 0.0;
    for row in root_rows.rows() {
        total = total + base_freq_logse(row, BF_DEFAULT);
    }
    total
}


/// Generates a candidate state with `move_fn`, recomputes the likelihoods of the
/// affected nodes, and keeps the candidate only if `accept_fn` approves it.
///
/// * `move_fn` - the move used to generate the candidate state.
/// * `accept_fn` - called as `accept_fn(&old_ll, &new_ll)`; returning `true`
///   accepts the candidate.
/// * `gen_data` - per-node likelihood data indexed by node id along the first
///   axis; only written to when the candidate is accepted.
///
/// Consumes `self` and returns the resulting state, so callers must rebind the
/// return value. When the candidate reports no changed nodes, or is rejected,
/// the returned state keeps the current topology and matrix.
///
/// Panics if a node produced by `changes_iter` has no left or right child, or
/// if the root id is missing from the recomputed likelihoods.
pub fn apply_move<T: TreeMove<R>>(mut self,
move_fn: T,
accept_fn: fn(&f64, &f64) -> bool,
gen_data: &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>) -> TreeState<R> {

// Lazily compute and cache the current likelihood the first time it is needed.
if self.ll.is_none() {
self.ll = Some(self.likelihood(gen_data));
}
let old_ll = self.ll.unwrap();

// NOTE(review): the matrix is read from the *current* state before the move is
// generated, so the node likelihoods below use the old matrix even when the move
// changes `mat` (as `MatrixMove` does) -- confirm this is intended.
let rate_mat = self.mat.get_matrix();
let new_ts = move_fn.generate(&self);

// If move did nothing, return old TreeState
if new_ts.changed_nodes.is_none() {
return self
}

// Do minimal likelihood updates (and push new values into HashMap temporarily)
let nodes = new_ts.top.changes_iter(new_ts.changed_nodes.unwrap());
let mut temp_likelihoods: HashMap<usize, ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>> = HashMap::new();

for node in nodes {
// check if in HM
// Both children are unwrapped, so every node yielded here must have two children.
let lchild = node.get_lchild().unwrap();
let rchild = node.get_rchild().unwrap();
let seql: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;
let seqr: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;

// For each child, prefer the value recomputed during this call; otherwise fall
// back to the stored data in `gen_data`.
match (temp_likelihoods.contains_key(&lchild), temp_likelihoods.contains_key(&rchild)) {
(true, true) => {
seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
},
(true, false) => {
seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
seqr = slice_data(rchild, &gen_data);
},
(false, true) => {
seql = slice_data(lchild, &gen_data);
seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
},
(false, false) => {
seql = slice_data(lchild, &gen_data);
seqr = slice_data(rchild, &gen_data);
},
};

// Combine the two child inputs using `matrix_exp` of the rate matrix and each
// child's branch length in the candidate topology.
let node_ll = node_likelihood(seql, seqr,
&matrix_exp(&rate_mat, new_ts.top.nodes[lchild].get_branchlen()),
&matrix_exp(&rate_mat, new_ts.top.nodes[rchild].get_branchlen()));

temp_likelihoods.insert(node.get_id(), node_ll);
}

// Calculate whole new topology likelihood at root
// The `unwrap` requires the root to be among the recomputed nodes.
let new_ll = temp_likelihoods
.get(&new_ts.top.get_root().get_id())
.unwrap()
.rows()
.into_iter()
.fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT));

// Likelihood decision rule
if accept_fn(&old_ll, &new_ll) {
// Drain hashmap into gen_data
for (i, ll_data) in temp_likelihoods.drain() {
gen_data.slice_mut(s![i, .., ..]).assign(&ll_data);
}
// Update Topology
// Only `nodes` and `tree_vec` are taken from the candidate; `top.likelihood`
// is left as it was.
self.top.nodes = new_ts.top.nodes;
self.top.tree_vec = new_ts.top.tree_vec;
self.mat = new_ts.mat;
self.ll = Some(new_ll);
};

self
let mge_mat = na::Matrix2::new(0.4, 0.6, 0.6, 0.4);
// let mut st = create_dummy_statedata(1, &t, &mge_mat);

}
}
// let mut pp = rate_matrix::GTR::default();
// println!("{:?}", pp.get_matrix());
// update_matrix(&mut t, always_accept, &mut gen_data, &mut pp);
Expand All @@ -204,15 +63,16 @@ pub fn main() {
let start = Instant::now();
for i in 0..0 {
println!{"Step {}", i};
// let new_v = random_vector(27);
// let mv = ExactMove{target_vector: new_v};
let new_v = random_vector(27);
let mv = ExactMove{target_vector: new_v};
// let mv = ChildSwap{};
let mv = PeturbVec{n: 1};
t.apply_move(mv, hillclimb_accept, &mut gen_data, &mut p.get_matrix());
// let mv = PeturbVec{n: 1};
ts.apply_move(mv, always_accept, &mut gen_data);
// t.apply_move(mv, hillclimb_accept, &mut gen_data, &mut p.get_matrix());

}
let end = Instant::now();
println!("New likelihood: {:?}", likelihood(&t, &gen_data));
println!("New likelihood: {:?}", ts.likelihood(&gen_data));
eprintln!("Done in {}s", end.duration_since(start).as_secs());
eprintln!("Done in {}ms", end.duration_since(start).as_millis());
}
Expand Down
34 changes: 23 additions & 11 deletions src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#[cfg(test)]
mod tests {
use crate::{likelihood, newick_to_vec::*};
use crate::newick_to_vec::{newick_to_vector, random_vector};
use crate::rate_matrix::{RateMatrix, GTR};
use crate::topology::Topology;
use crate::moves::ExactMove;
use crate::create_dummy_gendata;
use crate::always_accept;
use crate::treestate::TreeState;

#[test]
fn check_topology_build_manual() {
Expand Down Expand Up @@ -85,20 +86,29 @@ mod tests {
#[test]
fn update_tree() {
let p = GTR::default();
let mut t_1 = Topology::from_vec(&vec![0, 0, 1, 0]);
let t_1 = Topology::from_vec(&vec![0, 0, 1, 0]);

let mut gen_data = create_dummy_gendata(2, &t_1, &p.get_matrix());

let mut ts = TreeState{
top: t_1,
mat: p,
ll: None,
changed_nodes: None,
};

let vecs: Vec<Vec<usize>> = vec![vec![0, 0, 0, 0], vec![0, 0, 1, 0], vec![0, 0, 1, 2], vec![0, 0, 1, 1]];
let n = ts.top.nodes.len();

for vec in vecs {
let t_2 = Topology::from_vec(&vec);
let mv = ExactMove{target_vector: vec};
t_1.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix());

for i in 0..t_1.nodes.len() {
assert_eq!(t_1.nodes[i].get_parent(), t_2.nodes[i].get_parent());
assert_eq!(t_1.nodes[i].get_id(), t_2.nodes[i].get_id());
ts.apply_move(mv, always_accept, &mut gen_data);
// t_1.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix());

for i in 0..n {
assert_eq!(ts.top.nodes[i].get_parent(), t_2.nodes[i].get_parent());
assert_eq!(ts.top.nodes[i].get_id(), t_2.nodes[i].get_id());
};
}

Expand All @@ -109,16 +119,18 @@ mod tests {
let p = GTR::default();
let mut t = Topology::from_vec(&vec![0, 0, 0, 0]);
let mut gen_data = create_dummy_gendata(5, &t, &p.get_matrix());
let mut ts = TreeState{top: t, mat: p, ll: None, changed_nodes: None};

let old_likelihood = likelihood(&t, &gen_data);
let old_likelihood = ts.likelihood(&gen_data);

let mv = ExactMove{target_vector: vec![0, 0, 0, 1]};
t.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix());
ts.apply_move(mv, always_accept, &mut gen_data);
// t.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix());

let mv = ExactMove{target_vector: vec![0, 0, 0, 0]};
t.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix());
ts.apply_move(mv, always_accept, &mut gen_data);

let new_likelihood = likelihood(&t, &gen_data);
let new_likelihood = ts.likelihood(&gen_data);

assert_eq!(old_likelihood, new_likelihood);
}
Expand Down
Loading

0 comments on commit d1edf0b

Please sign in to comment.