Skip to content

Commit

Permalink
Working
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Mar 4, 2024
1 parent 9418293 commit 06cf0fa
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 14 deletions.
8 changes: 6 additions & 2 deletions lineage/LineageTree.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
""" This file contains the LineageTree class. """

from typing import Sequence
import numpy as np
import numpy.typing as npt
import operator
from .CellVar import CellVar
from .states.StateDistributionGamma import StateDistribution as StA
from .states.StateDistributionGaPhs import StateDistribution as StB


class LineageTree:
Expand All @@ -16,8 +19,9 @@ class LineageTree:
leaves_idx: np.ndarray
output_lineage: list[CellVar]
cell_to_daughters: np.ndarray
E: Sequence[StA | StB]

def __init__(self, list_of_cells: list, E: list):
def __init__(self, list_of_cells: list, E: Sequence[StA | StB]):
self.E = E
# output_lineage must be sorted according to generation
self.output_lineage = sorted(list_of_cells, key=operator.attrgetter("gen"))
Expand All @@ -32,7 +36,7 @@ def rand_init(
cls,
pi: np.ndarray,
T: np.ndarray,
E: list,
E: Sequence[StA | StB],
desired_num_cells: int,
censor_condition=0,
desired_experiment_time=2e12,
Expand Down
6 changes: 3 additions & 3 deletions lineage/states/StateDistributionGaPhs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import numpy as np

from .stateCommon import basic_censor
from .stateCommon import basic_censor, StateDistributionClass
from .StateDistributionGamma import StateDistribution as GammaSD
from ..CellVar import Time, CellVar


class StateDistribution:
class StateDistribution(StateDistributionClass):
"""For G1 and G2 separated as observations."""

def __init__(
Expand Down Expand Up @@ -54,7 +54,7 @@ def dof(self):
"""Return the degrees of freedom."""
return self.G1.dof() + self.G2.dof()

def logpdf(self, x: np.ndarray):
def logpdf(self, x: np.ndarray) -> np.ndarray:
"""To calculate the log-likelihood of observations to states."""

G1_LL = self.G1.logpdf(x[:, np.array([0, 2, 4])])
Expand Down
6 changes: 3 additions & 3 deletions lineage/states/StateDistributionGamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import scipy.stats as sp
from typing import Union, Literal

from .stateCommon import gamma_estimator, basic_censor, bern_estimator
from .stateCommon import gamma_estimator, basic_censor, bern_estimator, StateDistributionClass
from ..CellVar import Time, CellVar


class StateDistribution:
class StateDistribution(StateDistributionClass):
"""
StateDistribution for cells with gamma distributed times.
"""
Expand Down Expand Up @@ -136,7 +136,7 @@ def censor_lineage(
censor_condition: int,
full_lineage: list[CellVar],
desired_experiment_time=2e12,
):
) -> list[CellVar]:
"""
This function removes those cells that are intended to be removed.
These cells include the descendants of a cell that has died, or has lived beyonf the experimental end time.
Expand Down
44 changes: 38 additions & 6 deletions lineage/states/stateCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,39 @@
import warnings
from typing import Literal
import numpy as np
from numba import njit
import numpy.typing as npt
from scipy.optimize import minimize, Bounds, LinearConstraint
from scipy.special import gammaincc, gammaln
from ..CellVar import CellVar
from ctypes import CFUNCTYPE, c_double
from numba.extending import get_cython_function_address

arr_type = npt.NDArray[np.float64]


warnings.filterwarnings("ignore", message="Values in x were outside bounds")


class StateDistributionClass:
def rvs(self, size: int, rng=None):
raise NotImplementedError("dist not implemented.")

def dist(self, other) -> float:
raise NotImplementedError("dist not implemented.")

def dof(self) -> int:
raise NotImplementedError("dof not implemented.")

def logpdf(self, x: np.ndarray) -> np.ndarray:
raise NotImplementedError("logpdf not implemented.")

def estimator(self, x: np.ndarray, gammas: np.ndarray):
raise NotImplementedError("estimator not implemented.")

def censor_lineage(self, censor_condition: int, full_lineage: list[CellVar], desired_experiment_time=2e12) -> list[CellVar]:
raise NotImplementedError("censor_lineage not implemented.")


def basic_censor(cells: list):
"""
Censors a cell if the cell's parent is censored.
Expand All @@ -34,6 +57,14 @@ def bern_estimator(bern_obs: np.ndarray, gammas: np.ndarray):
return numerator / denominator


addr = get_cython_function_address("scipy.special.cython_special", "gammaincc")
gammaincc = CFUNCTYPE(c_double, c_double, c_double)(addr)

addr = get_cython_function_address("scipy.special.cython_special", "gammaln")
gammaln = CFUNCTYPE(c_double, c_double)(addr)


@njit
def gamma_LL(
logX: arr_type, gamma_obs: arr_type, time_cen: arr_type, gammas: arr_type, param_idx
):
Expand All @@ -49,10 +80,11 @@ def gamma_LL(
(x[0] - 1.0) * np.log(gobs) - gobs - glnA - logX[param_idx],
)

jidx = time_cen == 0.0
gamP = gammaincc(x[0], gobs[jidx])
gamP = np.maximum(gamP, 1e-35) # Clip if the probability hits exactly 0
outt -= np.sum(gammas[jidx] * np.log(gamP))
for jj, cen in enumerate(time_cen):
if cen == 0:
gamP = gammaincc(x[0], gobs[jj])
gamP = np.maximum(gamP, 1e-35) # Clip if the probability hits exactly 0
outt -= gammas[jj] * np.log(gamP)

assert np.isfinite(outt)
return outt
Expand Down Expand Up @@ -92,7 +124,7 @@ def gamma_estimator(

res = minimize(
gamma_LL,
jac="3-point",
jac="2-point",
x0=np.log(x0),
args=arrgs,
bounds=bnd,
Expand Down

0 comments on commit 06cf0fa

Please sign in to comment.