Skip to content

Commit

Permalink
twoctrls6
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Aug 27, 2024
1 parent ac070b6 commit 8ad2475
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 23 deletions.
6 changes: 6 additions & 0 deletions bean/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def main(args, return_data=False):

# Format bdata into data structure compatible with Pyro model
bdata = prepare_bdata(bdata, args, warn, prefix)
negctrl_idx = np.where(
bdata.guides[args.negctrl_col].map(lambda s: s.lower())
== args.negctrl_col_value.lower()
)[0]
ndata = DATACLASS_DICT[args.selection][model_label](
screen=bdata,
device=args.device,
Expand All @@ -115,13 +119,15 @@ def main(args, return_data=False):
popt=args.popt,
replicate_col=args.replicate_col,
use_bcmatch=(not args.ignore_bcmatch),
negctrl_guide_idx=negctrl_idx,
)
guide_index = ndata.screen.guides.index.copy()
assert len(guide_index) == bdata.n_obs, (len(guide_index), bdata.n_obs)
if return_data:
return ndata
# Build variant dataframe
adj_negctrl_idx = None

if args.library_design == "variant":
if not args.uniform_edit:
if "edit_rate" not in ndata.screen.guides.columns:
Expand Down
65 changes: 43 additions & 22 deletions bean/model/survival_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def NormalModel(

mu_center = mu_targets
mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0)
if hasattr(data, "negctrl_guide_idx"):
mu[data.negctrl_guide_idx, :] = 0.0
r = torch.exp(mu)
assert r.shape == (data.n_guides, 1)

Expand Down Expand Up @@ -223,6 +225,7 @@ def MixtureNormalModel(
sd_scale: float = 0.01,
scale_by_accessibility: bool = False,
fit_noise: bool = False,
mask_thres: int = 10,
prior_params: Optional[dict] = None,
):
"""
Expand Down Expand Up @@ -262,8 +265,13 @@ def MixtureNormalModel(
with pyro.plate("guide_plate0", 1):
with pyro.plate("guide_plate1", data.n_targets):
mu_targets = pyro.sample("mu_targets", mu_dist)
with pyro.plate("negctrl_plate", len(data.negctrl_guide_idx)):
mu_negctrl = pyro.sample("mu_negctrl", dist.Normal(0, 1))
mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1)
mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0)
# Fix negative control's mu to be 0
if hasattr(data, "negctrl_guide_idx"):
mu[data.negctrl_guide_idx, :] = mu_negctrl[:, None]
assert mu.shape == (data.n_guides, 2)
r = torch.exp(mu)

Expand Down Expand Up @@ -303,7 +311,7 @@ def MixtureNormalModel(
data.n_guides,
2,
), pi.shape
with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2):
with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2) as t:
with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
# if use_all_timepoints_for_pi:
# time_pi = data.timepoints
Expand All @@ -318,9 +326,11 @@ def MixtureNormalModel(
# obs=data.allele_counts,
# )
# else:
time_pi = data.control_timepoint
time_pi = data.control_timepoint[t]
# If pi is sampled in later timepoint, account for the selection.
expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi
expanded_allele_p = pi * torch.pow(
r.expand(data.n_reps, 1, -1, -1), time_pi
)
pyro.sample(
"control_allele_count",
dist.Multinomial(probs=expanded_allele_p, validate_args=False),
Expand All @@ -337,18 +347,28 @@ def MixtureNormalModel(
assert time.shape == (data.n_condits,)

with guide_plate:
alleles_p_time = torch.clamp(
time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1))
* torch.log(r).unsqueeze(0).expand((data.n_condits, -1, -1)),
max=MAX_LOGPI,
).exp()
alleles_p_time = torch.pow(
r.unsqueeze(0).expand((data.n_condits, -1, -1)),
time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 2)),
)
# alleles_p_time = torch.clamp(
# time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1))
# * torch.log(r).unsqueeze(0).expand((data.n_condits, -1, -1)),
# max=MAX_LOGPI,
# ).exp()
negctrl_abundance = pyro.param(
"negctrl_abundance",
torch.ones((data.n_condits,)),
constraint=constraints.positive,
)
alleles_p_time = (
alleles_p_time / negctrl_abundance.clamp(min=1e-5)[:, None, None]
)
assert alleles_p_time.shape == (data.n_condits, data.n_guides, 2)

expected_allele_p = (
pi.expand(data.n_reps, data.n_condits, -1, -1)
* alleles_p_time[None, :, :, :]
* q_0.unsqueeze(1).unsqueeze(-1).expand((-1, data.n_condits, -1, -1))
)
pi.expand(-1, data.n_condits, -1, -1) * alleles_p_time[None, :, :, :]
) * q_0.unsqueeze(1).unsqueeze(-1).expand((-1, data.n_condits, -1, -1))
expected_guide_p = expected_allele_p.sum(axis=-1)
assert expected_guide_p.shape == (
data.n_reps,
Expand All @@ -359,14 +379,7 @@ def MixtureNormalModel(
with replicate_plate2:
with pyro.plate("guide_plate3", data.n_guides, dim=-1):
a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0)
a_bcmatch = get_alpha(
expected_guide_p,
data.size_factor_bcmatch,
data.sample_mask,
data.a0_bcmatch,
)
# a_bcmatch = get_alpha(expected_guide_p, data.size_factor_bcmatch, data.sample_mask, data.a0_bcmatch)
# assert a.shape == a_bcmatch.shape == (data.n_reps, data.n_guides, data.n_condits)

assert (
data.X.shape
== data.X_bcmatch_masked.shape
Expand All @@ -378,7 +391,8 @@ def MixtureNormalModel(
)
with poutine.mask(
mask=torch.logical_and(
data.X_masked.permute(0, 2, 1).sum(axis=-1) > 10, data.repguide_mask
data.X_masked.permute(0, 2, 1).sum(axis=-1) > mask_thres,
data.repguide_mask,
)
):
pyro.sample(
Expand All @@ -387,9 +401,16 @@ def MixtureNormalModel(
obs=data.X_masked.permute(0, 2, 1),
)
if use_bcmatch:
a_bcmatch = get_alpha(
expected_guide_p,
data.size_factor_bcmatch,
data.sample_mask,
data.a0_bcmatch,
)
with poutine.mask(
mask=torch.logical_and(
data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) > 10,
data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1)
> mask_thres,
data.repguide_mask,
)
):
Expand Down
4 changes: 3 additions & 1 deletion bean/preprocessing/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import abc
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Tuple, List
from typing import Optional, Dict, Tuple, List, Sequence
from xmlrpc.client import Boolean
from copy import deepcopy
import torch
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(
popt: Optional[Tuple[float]] = None,
pi_popt: Optional[Tuple[float]] = None,
control_can_be_selected: bool = False,
negctrl_guide_idx: Optional[Sequence[int]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(
self.n_samples = len(screen.samples) # 8
self.n_guides = len(screen.guides)
self.n_reps = len(screen.samples[replicate_column].unique())
self.negctrl_guide_idx = negctrl_guide_idx
self.accessibility_col = accessibility_col
self.accessibility_bw_path = accessibility_bw_path
self.replicate_column = replicate_column
Expand Down

0 comments on commit 8ad2475

Please sign in to comment.