From 8ad2475d3d5d97c7e3f18a1484aca24534b9d417 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Tue, 27 Aug 2024 22:52:33 +0000 Subject: [PATCH] twoctrls6 --- bean/cli/run.py | 6 +++ bean/model/survival_model.py | 65 +++++++++++++++++++++----------- bean/preprocessing/data_class.py | 4 +- 3 files changed, 52 insertions(+), 23 deletions(-) diff --git a/bean/cli/run.py b/bean/cli/run.py index 5c51c54..7dde908 100755 --- a/bean/cli/run.py +++ b/bean/cli/run.py @@ -96,6 +96,10 @@ def main(args, return_data=False): # Format bdata into data structure compatible with Pyro model bdata = prepare_bdata(bdata, args, warn, prefix) + negctrl_idx = np.where( + bdata.guides[args.negctrl_col].map(lambda s: s.lower()) + == args.negctrl_col_value.lower() + )[0] ndata = DATACLASS_DICT[args.selection][model_label]( screen=bdata, device=args.device, @@ -115,6 +119,7 @@ def main(args, return_data=False): popt=args.popt, replicate_col=args.replicate_col, use_bcmatch=(not args.ignore_bcmatch), + negctrl_guide_idx=negctrl_idx, ) guide_index = ndata.screen.guides.index.copy() assert len(guide_index) == bdata.n_obs, (len(guide_index), bdata.n_obs) @@ -122,6 +127,7 @@ def main(args, return_data=False): return ndata # Build variant dataframe adj_negctrl_idx = None + if args.library_design == "variant": if not args.uniform_edit: if "edit_rate" not in ndata.screen.guides.columns: diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index f30388e..3c7ed24 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -55,6 +55,8 @@ def NormalModel( mu_center = mu_targets mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) + if hasattr(data, "negctrl_guide_idx"): + mu[data.negctrl_guide_idx, :] = 0.0 r = torch.exp(mu) assert r.shape == (data.n_guides, 1) @@ -223,6 +225,7 @@ def MixtureNormalModel( sd_scale: float = 0.01, scale_by_accessibility: bool = False, fit_noise: bool = False, + mask_thres: int = 10, prior_params: Optional[dict] = None, ): """ @@ -262,8 +265,13 @@ def MixtureNormalModel( with pyro.plate("guide_plate0", 1): with pyro.plate("guide_plate1", data.n_targets): mu_targets = pyro.sample("mu_targets", mu_dist) + with pyro.plate("negctrl_plate", len(data.negctrl_guide_idx)): + mu_negctrl = pyro.sample("mu_negctrl", dist.Normal(0, 1)) mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1) mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) + # Fix negative control's mu to be 0 + if hasattr(data, "negctrl_guide_idx"): + mu[data.negctrl_guide_idx, :] = mu_negctrl[:, None] assert mu.shape == (data.n_guides, 2) r = torch.exp(mu) @@ -303,7 +311,7 @@ def MixtureNormalModel( data.n_guides, 2, ), pi.shape - with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2): + with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2) as t: with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): # if use_all_timepoints_for_pi: # time_pi = data.timepoints @@ -318,9 +326,11 @@ def MixtureNormalModel( # obs=data.allele_counts, # ) # else: - time_pi = data.control_timepoint + time_pi = data.control_timepoint[t] # If pi is sampled in later timepoint, account for the selection. - expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi + expanded_allele_p = pi * torch.pow( + r.expand(data.n_reps, 1, -1, -1), time_pi + ) pyro.sample( "control_allele_count", dist.Multinomial(probs=expanded_allele_p, validate_args=False), @@ -337,18 +347,28 @@ def MixtureNormalModel( assert time.shape == (data.n_condits,) with guide_plate: - alleles_p_time = torch.clamp( - time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)) - * torch.log(r).unsqueeze(0).expand((data.n_condits, -1, -1)), - max=MAX_LOGPI, - ).exp() + alleles_p_time = torch.pow( + r.unsqueeze(0).expand((data.n_condits, -1, -1)), + time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 2)), + ) + # alleles_p_time = torch.clamp( + # time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)) + # * torch.log(r).unsqueeze(0).expand((data.n_condits, -1, -1)), + # max=MAX_LOGPI, + # ).exp() + negctrl_abundance = pyro.param( + "negctrl_abundance", + torch.ones((data.n_condits,)), + constraint=constraints.positive, + ) + alleles_p_time = ( + alleles_p_time / negctrl_abundance.clamp(min=1e-5)[:, None, None] + ) assert alleles_p_time.shape == (data.n_condits, data.n_guides, 2) expected_allele_p = ( - pi.expand(data.n_reps, data.n_condits, -1, -1) - * alleles_p_time[None, :, :, :] - * q_0.unsqueeze(1).unsqueeze(-1).expand((-1, data.n_condits, -1, -1)) - ) + pi.expand(-1, data.n_condits, -1, -1) * alleles_p_time[None, :, :, :] + ) * q_0.unsqueeze(1).unsqueeze(-1).expand((-1, data.n_condits, -1, -1)) expected_guide_p = expected_allele_p.sum(axis=-1) assert expected_guide_p.shape == ( data.n_reps, @@ -359,14 +379,7 @@ def MixtureNormalModel( with replicate_plate2: with pyro.plate("guide_plate3", data.n_guides, dim=-1): a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0) - a_bcmatch = get_alpha( - expected_guide_p, - data.size_factor_bcmatch, - data.sample_mask, - data.a0_bcmatch, - ) - # a_bcmatch = get_alpha(expected_guide_p, data.size_factor_bcmatch, data.sample_mask, data.a0_bcmatch) - # assert a.shape == a_bcmatch.shape == (data.n_reps, data.n_guides, data.n_condits) + assert ( data.X.shape == data.X_bcmatch_masked.shape @@ -378,7 +391,8 @@ def MixtureNormalModel( ) with poutine.mask( mask=torch.logical_and( - data.X_masked.permute(0, 2, 1).sum(axis=-1) > 10, data.repguide_mask + data.X_masked.permute(0, 2, 1).sum(axis=-1) > mask_thres, + data.repguide_mask, ) ): pyro.sample( @@ -387,9 +401,16 @@ def MixtureNormalModel( obs=data.X_masked.permute(0, 2, 1), ) if use_bcmatch: + a_bcmatch = get_alpha( + expected_guide_p, + data.size_factor_bcmatch, + data.sample_mask, + data.a0_bcmatch, + ) with poutine.mask( mask=torch.logical_and( - data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) > 10, + data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) + > mask_thres, data.repguide_mask, ) ): diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 2b1d63c..ab4e419 100755 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -2,7 +2,7 @@ import abc import logging from dataclasses import dataclass -from typing import Optional, Dict, Tuple, List +from typing import Optional, Dict, Tuple, List, Sequence from xmlrpc.client import Boolean from copy import deepcopy import torch @@ -49,6 +49,7 @@ def __init__( popt: Optional[Tuple[float]] = None, pi_popt: Optional[Tuple[float]] = None, control_can_be_selected: bool = False, + negctrl_guide_idx: Optional[Sequence[int]] = None, **kwargs, ): """ @@ -109,6 +110,7 @@ def __init__( self.n_samples = len(screen.samples) # 8 self.n_guides = len(screen.guides) self.n_reps = len(screen.samples[replicate_column].unique()) + self.negctrl_guide_idx = negctrl_guide_idx self.accessibility_col = accessibility_col self.accessibility_bw_path = accessibility_bw_path self.replicate_column = replicate_column