From 610b102adb510375ac98406a4eef245149f26c1f Mon Sep 17 00:00:00 2001 From: Aaron Meyer Date: Wed, 28 Feb 2024 06:10:58 -0800 Subject: [PATCH] Commit --- lineage/Analyze.py | 2 +- lineage/CellVar.py | 23 ------- lineage/HMM/E_step.py | 5 +- lineage/LineageTree.py | 19 ++++-- lineage/figures/figure18.py | 4 +- lineage/states/StateDistributionGaPhs.py | 5 +- lineage/states/StateDistributionGamma.py | 5 +- lineage/states/stateCommon.py | 8 +-- lineage/tests/test_CellVar.py | 84 ------------------------ lineage/tests/test_StateDistribution.py | 8 +-- 10 files changed, 30 insertions(+), 133 deletions(-) delete mode 100644 lineage/tests/test_CellVar.py diff --git a/lineage/Analyze.py b/lineage/Analyze.py index d9ebc2c57..33571414e 100644 --- a/lineage/Analyze.py +++ b/lineage/Analyze.py @@ -191,7 +191,7 @@ def Results(tHMMobj: tHMM, LL: float) -> dict[str, Any]: results_dict["total_number_of_lineages"] = len(tHMMobj.X) results_dict["LL"] = LL results_dict["total_number_of_cells"] = sum( - [len(lineage.output_lineage) for lineage in tHMMobj.X] + [len(lineage) for lineage in tHMMobj.X] ) true_states_by_lineage = [ diff --git a/lineage/CellVar.py b/lineage/CellVar.py index 9bb0cd63d..7f8c64e8b 100644 --- a/lineage/CellVar.py +++ b/lineage/CellVar.py @@ -38,23 +38,6 @@ def __init__(self, parent: Optional["CellVar"], state: Optional[int] = None): self.right = None self.obs = None - def divide(self, T: np.ndarray, rng=None): - """ - Member function that performs division of a cell. - Equivalent to adding another timestep in a Markov process. - :param T: The array containing the likelihood of a cell switching states. - """ - rng = np.random.default_rng(rng) - # Checking that the inputs are of the right shape - assert T.shape[0] == T.shape[1] - - # roll a loaded die according to the row in the transtion matrix - left_state, right_state = rng.choice(T.shape[0], size=2, p=T[self.state, :]) - self.left = CellVar(state=left_state, parent=self) - self.right = CellVar(state=right_state, parent=self) - - return self.left, self.right - def isLeafBecauseTerminal(self) -> bool: """ Returns true when a cell is a leaf with no children. @@ -79,12 +62,6 @@ def isLeaf(self) -> bool: # otherwise, it itself is observed and at least one of its daughters is observed return False - def isRootParent(self) -> bool: - """ - Returns true if this cell is the first cell in a lineage. - """ - return self.parent is None - @dataclass(init=True, repr=True, eq=True, order=True) class Time: diff --git a/lineage/HMM/E_step.py b/lineage/HMM/E_step.py index 2128bd20f..14b4f8b05 100644 --- a/lineage/HMM/E_step.py +++ b/lineage/HMM/E_step.py @@ -138,10 +138,7 @@ def get_beta( ) # MSD of the respective lineage ELMSD = EL * MSD - cIDXs = np.arange(first_leaf) - cIDXs = np.flip(cIDXs) - - for pii in cIDXs: + for pii in range(first_leaf - 1, -1, -1): ch_ii = np.array([pii * 2 + 1, pii * 2 + 2]) ratt = (beta[ch_ii, :] / MSD_array[ch_ii, :]) @ T.T fac1 = np.prod(ratt, axis=0) * ELMSD[pii, :] diff --git a/lineage/LineageTree.py b/lineage/LineageTree.py index 7c683dcf1..1fd8336c4 100644 --- a/lineage/LineageTree.py +++ b/lineage/LineageTree.py @@ -70,14 +70,19 @@ def rand_init( ) # roll the dice and yield the state for the first cell first_cell = CellVar(parent=None, state=first_state) # create first cell full_lineage = [first_cell] # instantiate lineage with first cell + pIDX = 0 - for cell in full_lineage: # letting the first cell proliferate - if cell.isLeaf(): # if the cell has no daughters... - # make daughters by dividing and assigning states - full_lineage.extend(cell.divide(T, rng=rng)) + # fill in the cells + while len(full_lineage) < desired_num_cells: # letting the first cell proliferate + parent = full_lineage[pIDX] + states = rng.choice(T.shape[0], size=2, p=T[parent.state, :]) - if len(full_lineage) >= desired_num_cells: - break + full_lineage.append(CellVar(parent=parent, state=states[0])) + full_lineage.append(CellVar(parent=parent, state=states[1])) + parent.left = full_lineage[-2] + parent.right = full_lineage[-1] + + pIDX += 1 # Assign observations for i_state in range(pi.size): @@ -140,7 +145,7 @@ def get_Emission_Likelihoods(X: list[LineageTree], E: list) -> list: EL = [] ii = 0 for lineageObj in X: # for each lineage in our Population - nl = len(lineageObj.output_lineage) # getting the lineage length + nl = len(lineageObj) # getting the lineage length EL.append(ELstack[ii : (ii + nl), :]) # append the EL_array for each lineage ii += nl diff --git a/lineage/figures/figure18.py b/lineage/figures/figure18.py index edda64e25..0f44b985c 100644 --- a/lineage/figures/figure18.py +++ b/lineage/figures/figure18.py @@ -49,7 +49,7 @@ tmp_lineage = LineageTree.rand_init( pi, T, E, desired_num_cells, censor_condition=3, desired_experiment_time=96 ) - while len(tmp_lineage.output_lineage) < 3: + while len(tmp_lineage) < 3: tmp_lineage = LineageTree.rand_init( pi, T, @@ -59,7 +59,7 @@ desired_experiment_time=96, ) population.append(tmp_lineage) - nn += len(tmp_lineage.output_lineage) + nn += len(tmp_lineage) # Adding populations into a holder for analysing list_of_populations.append(population) diff --git a/lineage/states/StateDistributionGaPhs.py b/lineage/states/StateDistributionGaPhs.py index 41ffa36ba..48c272173 100644 --- a/lineage/states/StateDistributionGaPhs.py +++ b/lineage/states/StateDistributionGaPhs.py @@ -87,8 +87,9 @@ def assign_times(self, full_lineage: list): This is used in the creation of LineageTrees """ # traversing the cells by generation - for cell in full_lineage: - if cell.isRootParent(): + for ii, cell in enumerate(full_lineage): + # handle root separately + if ii == 0: cell.time = Time(0, cell.obs[2] + cell.obs[3]) cell.time.transition_time = 0 + cell.obs[2] else: diff --git a/lineage/states/StateDistributionGamma.py b/lineage/states/StateDistributionGamma.py index a77421de5..2b304d959 100644 --- a/lineage/states/StateDistributionGamma.py +++ b/lineage/states/StateDistributionGamma.py @@ -135,8 +135,9 @@ def assign_times(self, full_lineage: list[CellVar]): This is used in the creation of LineageTrees. """ # traversing the cells by generation - for cell in full_lineage: - if cell.isRootParent(): + for ii, cell in enumerate(full_lineage): + # if root + if ii == 0: cell.time = Time(0, cell.obs[1]) else: cell.time = Time( diff --git a/lineage/states/stateCommon.py b/lineage/states/stateCommon.py index b5ba400af..aa9d1bb70 100644 --- a/lineage/states/stateCommon.py +++ b/lineage/states/stateCommon.py @@ -14,10 +14,9 @@ def basic_censor(cells: list): """ Censors a cell if the cell's parent is censored. """ - for cell in cells: - if not cell.isRootParent(): - if not cell.parent.observed: - cell.observed = False + for cell in cells[1:]: + if not cell.parent.observed: + cell.observed = False def bern_estimator(bern_obs: np.ndarray, gammas: np.ndarray): @@ -114,6 +113,7 @@ def gamma_estimator( method="SLSQP", constraints=linc, ) + print(res.nit) assert res.success return np.exp(res.x) diff --git a/lineage/tests/test_CellVar.py b/lineage/tests/test_CellVar.py deleted file mode 100644 index 70ad07a49..000000000 --- a/lineage/tests/test_CellVar.py +++ /dev/null @@ -1,84 +0,0 @@ -""" Unit test file. """ - -import unittest -import numpy as np -from ..CellVar import CellVar as c - - -# pylint: disable=protected-access - - -class TestModel(unittest.TestCase): - """ - Unit test class for the cell class. - """ - - def test_cellVar(self): - """ - Make sure cell state assignment is correct. - """ - left_state = 0 - right_state = 1 - - cell_left = c(state=left_state, parent=None) - cell_right = c(state=right_state, parent=None) - - self.assertTrue(cell_left.state == 0) - self.assertTrue(cell_right.state == 1) - - def test_cell_divide(self): - """ - Tests the division of the cells. - """ - T = np.array([[1.0, 0.0], [0.0, 1.0]]) - - parent_state = 1 - cell = c(state=parent_state, parent=None) - left_cell, right_cell = cell.divide(T) - # the probability of switching states is 0 - self.assertTrue(left_cell.state == 1) - self.assertTrue(right_cell.state == 1) - self.assertTrue(right_cell.parent is cell and left_cell.parent is cell) - self.assertTrue(cell.left is left_cell and cell.right is right_cell) - self.assertTrue(not cell.parent) - self.assertTrue(cell.gen == 1) - self.assertTrue(left_cell.gen == 2 and right_cell.gen == 2) - - parent_state = 0 - cell = c(state=parent_state, parent=None) - left_cell, right_cell = cell.divide(T) - # the probability of switching states is 0 - self.assertTrue(left_cell.state == 0) - self.assertTrue(right_cell.state == 0) - self.assertTrue(right_cell.parent is cell and left_cell.parent is cell) - self.assertTrue(cell.left is left_cell and cell.right is right_cell) - self.assertTrue(not cell.parent) - self.assertTrue(cell.gen == 1) - self.assertTrue(left_cell.gen == 2 and right_cell.gen == 2) - - def test_isRootParent(self): - """ - Tests whether the correct root parent asserts work. - """ - T = np.array([[1.0, 0.0], [0.0, 1.0]]) - - parent_state = 1 - cell = c(state=parent_state, parent=None) - left_cell, right_cell = cell.divide(T) - self.assertTrue(cell.isRootParent()) - self.assertFalse(left_cell.isRootParent() and right_cell.isRootParent()) - - def test_isLeafBecauseTerminal(self): - """ - Tests whether the leaf cells are correctly checked. - """ - T = np.array([[1.0, 0.0], [0.0, 1.0]]) - - parent_state = 1 - cell = c(state=parent_state, parent=None) - self.assertTrue(cell.isLeafBecauseTerminal()) - left_cell, right_cell = cell.divide(T) - self.assertFalse(cell.isLeafBecauseTerminal()) - self.assertTrue( - left_cell.isLeafBecauseTerminal() and right_cell.isLeafBecauseTerminal() - ) diff --git a/lineage/tests/test_StateDistribution.py b/lineage/tests/test_StateDistribution.py index 70ef91a99..86f5b0796 100644 --- a/lineage/tests/test_StateDistribution.py +++ b/lineage/tests/test_StateDistribution.py @@ -114,10 +114,10 @@ def test_censor(self): as expected. """ for lin in self.population: - for cell in lin.output_lineage: - if not cell.isRootParent: - if not cell.parent.observed: - self.assertFalse(cell.observed) + # Skip root parent + for cell in lin.output_lineage[1:]: + if not cell.parent.observed: + self.assertFalse(cell.observed) @pytest.mark.parametrize("dist", [StateDistribution, StateDistPhase])