From 568c599410e2bae461530da6b80bcb0b15739a53 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Thu, 9 Nov 2023 11:00:28 -0500 Subject: [PATCH 01/13] initial commit for proliferation screen --- README.md | 4 +- bean/model/model.py | 1 - bean/model/readwrite.py | 39 ++-- bean/model/utils.py | 346 ------------------------------- bean/preprocessing/data_class.py | 345 +++++++++++++++++++++++------- bean/preprocessing/utils.py | 30 ++- bin/bean-run | 124 ++--------- 7 files changed, 345 insertions(+), 544 deletions(-) diff --git a/README.md b/README.md index 06a0fa0..ef0364a 100644 --- a/README.md +++ b/README.md @@ -248,7 +248,7 @@ bean-filter my_sorting_screen.h5ad \ ## `bean-run`: Quantify variant effects BEAN uses a Bayesian network to incorporate the gRNA editing outcome and provide a posterior estimate of the variant phenotype. The Bayesian network reflects the data generation process. Briefly, -1. Cellular phenotype is modeled as the Gaussian mixture distribution of wild-type phenotype and variant phenotype. +1. Cellular phenotype (either the phenotype cells are sorted on, for sorting screens, or log(proliferation rate), for survival screens) is modeled as the Gaussian mixture distribution of wild-type phenotype and variant phenotype. 2. The weights of the mixture components are inferred from the reporter editing outcome and the chromatin accessibility of the loci. 3. Cells with each gRNA, formulated as the mixture distribution, are sorted by the phenotypic quantile to produce the gRNA counts. For the full detail, see the method section of the [BEAN manuscript](https://www.

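As a rough illustration of steps 1–3 (a minimal sketch with hypothetical names and numbers, not BEAN's actual implementation), the expected fraction of a guide's cells landing in each sorting bin follows from evaluating the two mixture components between the bin's phenotypic quantile cut-points:

```python
# Sketch of the sorting-screen generative model described above.
# All names and values here are illustrative, not BEAN's API.
from scipy.stats import norm

pi_edited = 0.6                    # step 2: mixture weight (fraction of edited cells)
mu_wt, sd_wt = 0.0, 1.0            # step 1: wild-type phenotype component
mu_variant, sd_variant = 0.5, 1.0  # step 1: variant phenotype component

# step 3: sorting bins defined by phenotypic quantiles, e.g. bottom and top 30%
for lq, uq in [(0.0, 0.3), (0.7, 1.0)]:
    # quantile cut-points on the (standardized) phenotype scale
    lower, upper = norm.ppf(lq), norm.ppf(uq)
    # probability that a cell from each mixture component falls inside the bin
    p_wt = norm.cdf(upper, mu_wt, sd_wt) - norm.cdf(lower, mu_wt, sd_wt)
    p_var = norm.cdf(upper, mu_variant, sd_variant) - norm.cdf(lower, mu_variant, sd_variant)
    # expected fraction of this guide's cells sorted into the bin
    p_bin = (1 - pi_edited) * p_wt + pi_edited * p_var
    print(f"bin ({lq}, {uq}): expected guide fraction {p_bin:.3f}")
```

`bean-run` fits the variant phenotype mean (and, unless `--uniform-edit` is set, the per-guide editing weight) so that these expected per-bin fractions reproduce the observed gRNA counts.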
```bash -bean-run variant[tiling] my_sorting_screen_filtered.h5ad \ +bean-run sorting[survival] variant[tiling] my_sorting_screen_filtered.h5ad \ [--uniform-edit, --scale-by-acc [--acc-bw-path accessibility_signal.bw, --acc-col accessibility]] \ -o output_prefix/ \ --fit-negctrl diff --git a/bean/model/model.py b/bean/model/model.py index a5c3b66..c6be6e9 100644 --- a/bean/model/model.py +++ b/bean/model/model.py @@ -91,7 +91,6 @@ def NormalModel( obs=data.X_masked.permute(0, 2, 1), ) if use_bcmatch: - print(f"Use_bcmatch:{use_bcmatch}") a_bcmatch = get_alpha( expected_guide_p, data.size_factor_bcmatch, diff --git a/bean/model/readwrite.py b/bean/model/readwrite.py index fc34d11..b41de39 100644 --- a/bean/model/readwrite.py +++ b/bean/model/readwrite.py @@ -57,41 +57,50 @@ def write_result_table( adjust_confidence_negatives: np.ndarray = None, guide_index: Optional[Sequence[str]] = None, guide_acc: Optional[Sequence] = None, + sd_is_fitted: bool = True, return_result: bool = False, ) -> Union[pd.DataFrame, None]: """Combine target information and scores to write result table to a csv file or return it.""" if param_hist_dict["params"]["mu_loc"].dim() == 2: mu = param_hist_dict["params"]["mu_loc"].detach()[:, 0].cpu().numpy() mu_sd = param_hist_dict["params"]["mu_scale"].detach()[:, 0].cpu().numpy() - sd = param_hist_dict["params"]["sd_loc"].detach().exp()[:, 0].cpu().numpy() + if sd_is_fitted: + sd = param_hist_dict["params"]["sd_loc"].detach().exp()[:, 0].cpu().numpy() elif param_hist_dict["params"]["mu_loc"].dim() == 1: mu = param_hist_dict["params"]["mu_loc"].detach().cpu().numpy() mu_sd = param_hist_dict["params"]["mu_scale"].detach().cpu().numpy() - sd = param_hist_dict["params"]["sd_loc"].detach().exp().cpu().numpy() + if sd_is_fitted: + sd = param_hist_dict["params"]["sd_loc"].detach().exp().cpu().numpy() else: raise ValueError( f'`mu_loc` has invalid shape of {param_hist_dict["params"]["mu_loc"].shape}' ) - fit_df = pd.DataFrame( - { - "mu": mu, - "mu_sd": mu_sd, - "mu_z": mu / mu_sd, - "sd": sd, - } - ) + param_dict = { + "mu": mu, + "mu_sd": mu_sd, + "mu_z": mu / mu_sd, + } + if sd_is_fitted: + param_dict["sd"] = sd + fit_df = pd.DataFrame(param_dict) fit_df["novl"] = get_novl(fit_df, "mu", "mu_sd") if "negctrl" in param_hist_dict.keys(): print("Normalizing with common negative control distribution") mu0 = param_hist_dict["negctrl"]["params"]["mu_loc"].detach().cpu().numpy() - sd0 = ( - param_hist_dict["negctrl"]["params"]["sd_loc"].detach().exp().cpu().numpy() - ) - print(f"Fitted mu0={mu0}, sd0={sd0}.") + if sd_is_fitted: + sd0 = ( + param_hist_dict["negctrl"]["params"]["sd_loc"] + .detach() + .exp() + .cpu() + .numpy() + ) + print(f"Fitted mu0={mu0}" + (f", sd0={sd0}." 
if sd_is_fitted else "")) fit_df["mu_scaled"] = (mu - mu0) / sd0 fit_df["mu_sd_scaled"] = mu_sd / sd0 fit_df["mu_z_scaled"] = fit_df.mu_scaled / fit_df.mu_sd_scaled - fit_df["sd_scaled"] = sd / sd0 + if sd_is_fitted: + fit_df["sd_scaled"] = sd / sd0 fit_df["novl_scaled"] = get_novl(fit_df, "mu_scaled", "mu_sd_scaled") fit_df = pd.concat( diff --git a/bean/model/utils.py b/bean/model/utils.py index 10f059c..7767232 100644 --- a/bean/model/utils.py +++ b/bean/model/utils.py @@ -1,355 +1,9 @@ -import os -import sys -import argparse -from tqdm import tqdm -import logging -import pickle as pkl -import pandas as pd import torch import torch.distributions as tdist import pyro import pyro.distributions as dist import pyro.distributions.constraints as constraints -logging.basicConfig( - level=logging.INFO, - format="%(levelname)-5s @ %(asctime)s:\n\t %(message)s \n", - datefmt="%a, %d %b %Y %H:%M:%S", - stream=sys.stderr, - filemode="w", -) -error = logging.critical -warn = logging.warning -debug = logging.debug -info = logging.info -pyro.set_rng_seed(101) - - -def run_inference( - model, guide, data, initial_lr=0.01, gamma=0.1, num_steps=2000, autoguide=False -): - pyro.clear_param_store() - lrd = gamma ** (1 / num_steps) - svi = pyro.infer.SVI( - model=model, - guide=guide, - optim=pyro.optim.ClippedAdam({"lr": initial_lr, "lrd": lrd}), - loss=pyro.infer.Trace_ELBO(), - ) - losses = [] - try: - for t in tqdm(range(num_steps)): - loss = svi.step(data) - if t % 100 == 0: - print(f"loss {loss} @ iter {t}") - losses.append(loss) - except ValueError as exc: - error( - "Error occurred during fitting. Saving temporary output at tmp_result.pkl." - ) - with open("tmp_result.pkl", "wb") as handle: - pkl.dump({"param": pyro.get_param_store()}, handle) - - raise ValueError( - f"Fitting halted for command: {' '.join(sys.argv)} with following error: \n {exc}" - ) - return { - "loss": losses, - "params": pyro.get_param_store(), - } - - -def _get_guide_target_info(bdata, args, cols_include=[]): - guide_info = bdata.guides.copy() - target_info = ( - guide_info[ - [args.target_col] - + [ - col - for col in guide_info.columns - if ( - ( - (col.startswith("target_")) - and len(guide_info[[args.target_col, col]].drop_duplicates()) - == len(guide_info[args.target_col].unique()) - ) - or col in cols_include - ) - and col != args.target_col - ] - ] - .drop_duplicates() - .set_index(args.target_col, drop=True) - ) - if "edit_rate" in guide_info.columns.tolist(): - edit_rate_info = ( - guide_info[[args.target_col, "edit_rate"]] - .groupby(args.target_col, sort=False) - .agg({"edit_rate": ["mean", "std"]}) - ) - edit_rate_info.columns = edit_rate_info.columns.get_level_values(1) - edit_rate_info = edit_rate_info.rename( - columns={"mean": "edit_rate_mean", "std": "edit_rate_std"} - ) - target_info = target_info.join(edit_rate_info) - return target_info - - -def none_or_str(value): - if value == "None": - return None - return value - - -def parse_args(): - print( - r""" - _ _ - / \ '\ - | \ \ _ _ _ _ _ _ - \ \ | | '_| || | ' \ - `.__|/ |_| \_,_|_||_| - """ - ) - print("bean-run: Run model to identify targeted variants and their impact.") - parser = argparse.ArgumentParser(description="Run model on data.") - - parser.add_argument( - "mode", - type=str, - help="[variant, tiling]- Screen type whether to run variant or tiling screen model.", - ) - parser.add_argument("bdata_path", type=str, help="Path of an ReporterScreen object") - parser.add_argument( - "--rep-pi", - "-r", - action="store_true", - default=False, - help="Fit 
replicate specific scaling factor. Recommended to set as True if you expect variable editing activity across biological replicates.", - ) - parser.add_argument( - "--uniform-edit", - "-p", - action="store_true", - default=False, - help="Assume uniform editing rate for all guides.", - ) - parser.add_argument( - "--scale-by-acc", - action="store_true", - default=False, - help="Scale guide editing efficiency by the target loci accessibility", - ) - parser.add_argument( - "--acc-bw-path", - type=str, - default=None, - help="Accessibility .bigWig file to be used to assign accessibility of guides.", - ) - parser.add_argument( - "--acc-col", - type=str, - default=None, - help="Column name in bdata.guides that specify raw ATAC-seq signal.", - ) - parser.add_argument( - "--const-pi", - default=False, - action="store_true", - help="Use constant pi provided in --guide-activity-col (instead of fitting from reporter data)", - ) - parser.add_argument( - "--shrink-alpha", - default=False, - action="store_true", - help="Instead of using the trend-fitted alpha values, use estimated alpha values for each gRNA that are shrunk towards the fitted trend.", - ) - parser.add_argument( - "--condition-col", - default="bin", - type=str, - help="Column key in `bdata.samples` that describes experimental condition.", - ) - parser.add_argument( - "--control-condition-label", - default="bulk", - type=str, - help="Value in `bdata.samples[condition_col]` that indicates control experimental condition.", - ) - parser.add_argument( - "--replicate-col", - default="rep", - type=str, - help="Column key in `bdata.samples` that describes experimental replicates.", - ) - parser.add_argument( - "--target-col", - default="target", - type=str, - help="Column key in `bdata.guides` that describes the target element of each guide.", - ) - parser.add_argument( - "--guide-activity-col", - "-a", - type=str, - default=None, - help="Column in ReporterScreen.guides DataFrame showing the editing rate estimated via external tools", - ) - parser.add_argument( - "--outdir", - "-o", - default=".", - type=str, - help="Directory to save the run result.", - ) - parser.add_argument( - "--result-suffix", - default="", - type=str, - help="Suffix of the output files", - ) - parser.add_argument( - "--sorting-bin-upper-quantile-col", - "-uq", - help="Column name with upper quantile values of each sorting bin in [Reporter]Screen.samples (or AnnData.var)", - default="upper_quantile", - ) - parser.add_argument( - "--sorting-bin-lower-quantile-col", - "-lq", - help="Column name with lower quantile values of each sorting bin in [Reporter]Screen.samples (or AnnData var)", - default="lower_quantile", - ) - parser.add_argument("--cuda", action="store_true", default=False, help="run on GPU") - parser.add_argument( - "--sample-mask-col", - type=str, - default=None, - help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples.", - ) - parser.add_argument( - "--fit-negctrl", - action="store_true", - default=False, - help="Fit the shared negative control distribution to normalize the fitted parameters", - ) - parser.add_argument( - "--negctrl-col", - type=str, - default="target_group", - help="Column in bdata.obs specifying if a guide is negative control. 
If the `bdata.guides[negctrl_col].lower() == negctrl_col_value`, it is treated as negative control guide.", - ) - parser.add_argument( - "--negctrl-col-value", - type=str, - default="negctrl", - help="Column value in bdata.guides specifying if a guide is negative control. If the `bdata.guides[negctrl_col].lower() == negctrl_col_value`, it is treated as negative control guide.", - ) - parser.add_argument( - "--repguide-mask", - type=none_or_str, - default="repguide_mask", - help="n_replicate x n_guide mask to mask the outlier guides. screen.uns[repguide_mask] will be used.", - ) - parser.add_argument( - "--device", - type=str, - default=None, - help="Optionally use GPU if provided valid GPU device name (ex. cuda:0)", - ) - parser.add_argument( - "--ignore-bcmatch", - action="store_true", - default=False, - help="If provided, even if the screen object has .X_bcmatch, ignore the count when fitting.", - ) - parser.add_argument( - "--allele-df-key", - type=str, - default=None, - help="screen.uns[allele_df_key] will be used as the allele count.", - ) - parser.add_argument( - "--splice-site-path", - type=str, - default=None, - help="Path to splicing site", - ) - parser.add_argument( - "--control-guide-tag", - type=none_or_str, - default="CONTROL", - help="If this string is in guide name, treat each guide separately not to mix the position. Used for negative controls.", - ) - parser.add_argument( - "--dont-fit-noise", # TODO: add check args - action="store_true", - ) - parser.add_argument( - "--dont-adjust-confidence-by-negative-control", - action="store_true", - help="Adjust confidence by negative controls. For variant mode, this uses negative control variants. For tiling mode, adjusts confidence by synonymous edits.", - ) - parser.add_argument( - "--load-existing", # TODO: add check args - action="store_true", - help="Load existing .pkl file if present.", - ) - - return parser.parse_args() - - -def check_args(args, bdata): - args.adjust_confidence_by_negative_control = ( - not args.dont_adjust_confidence_by_negative_control - ) - if args.scale_by_acc: - if args.acc_col is None and args.acc_bw_path is None: - raise ValueError( - "--scale-by-acc not accompanied by --acc-col nor --acc-bw-path to use. Pass either one." - ) - elif args.acc_col is not None and args.acc_bw_path is not None: - warn( - "Both --acc-col and --acc-bw-path is specified. --acc-bw-path is ignored." - ) - args.acc_bw_path = None - if args.outdir is None: - args.outdir = os.path.dirname(args.bdata_path) - if args.mode == "variant": - pass - elif args.mode == "tiling": - if args.allele_df_key is None: - raise ValueError( - "--allele-df-key not provided for tiling screen. Feed in the key then allele counts in screen.uns[allele_df_key] will be used." - ) - else: - raise ValueError( - "Invalid mode provided. Select either 'variant' or 'tiling'." - ) # TODO: change this into discrete modes via argparse - if args.fit_negctrl: - n_negctrl = ( - bdata.guides[args.negctrl_col].map(lambda s: s.lower()) - == args.negctrl_col_value.lower() - ).sum() - if not n_negctrl >= 20: - raise ValueError( - f"Not enough negative control guide in the input data: {n_negctrl}. Check your input arguments." - ) - if args.repguide_mask is not None and args.repguide_mask not in bdata.uns.keys(): - bdata.uns[args.repguide_mask] = pd.DataFrame( - index=bdata.guides.index, columns=bdata.samples[args.replicate_col].unique() - ).fillna(1) - warn( - f"{args.bdata_path} does not have replicate x guide outlier mask. All guides are included in analysis." 
- ) - if args.sample_mask_col is not None: - if args.sample_mask_col not in bdata.samples.columns.tolist(): - raise ValueError( - f"{args.bdata_path} does not have specified sample mask column {args.sample_mask_col} in .samples" - ) - - return args, bdata - def get_alpha( expected_guide_p, size_factor, sample_mask, a0, epsilon=1e-5, normalize_by_a0=True diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index ea4953e..e3c518d 100644 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -1,8 +1,8 @@ import sys import abc +import logging from dataclasses import dataclass from typing import Dict, Tuple -import logging from xmlrpc.client import Boolean from copy import deepcopy import torch @@ -13,7 +13,6 @@ from .get_pi_alpha0 import get_fitted_alpha0 as get_fitted_pi_alpha0 from .get_pi_alpha0 import get_pred_alpha0 as get_pred_pi_alpha0 from .utils import ( - Alias, get_accessibility_guides, get_edit_to_index_dict, _assign_rep_ids_and_sort, @@ -47,8 +46,13 @@ def __init__( device: str = None, replicate_column: str = "rep", pi_popt: Tuple[float] = None, + control_can_be_selected: bool = False, **kwargs, ): + """ + Args + control_can_be_selected: If True, screen.samples[condition_column] == control_condition can also be included in effect size inference if its condition column is not NA (Currently only suppoted for prolifertion screens). + """ # TODO: remove replicate with too small number of (ex. only 1) sorting bin self.condition_column = condition_column self.device = device @@ -66,9 +70,12 @@ def __init__( ) self.screen = screen - self.screen_selected = screen[ - :, screen.samples[condition_column] != control_condition - ] + if not control_can_be_selected: + self.screen_selected = screen[ + :, screen.samples[condition_column] != control_condition + ] + else: + self.screen_selected = screen[:, ~screen.samples[condition_column].isnull()] self.n_condits = len( self.screen_selected.var[condition_column].unique() @@ -178,7 +185,16 @@ def __getitem__(self, guide_idx): def transform_data(self, X, n_bins=None): if n_bins is None: n_bins = self.n_condits - x = torch.as_tensor(X).T.reshape((self.n_reps, n_bins, self.n_guides)).float() + try: + x = ( + torch.as_tensor(X) + .T.reshape((self.n_reps, n_bins, self.n_guides)) + .float() + ) + except RuntimeError: + print((self.n_reps, n_bins, self.n_guides)) + print(X.shape) + exit(1) if self.device is not None: x = x.cuda() return x @@ -668,7 +684,7 @@ def transform_allele(self, adata, reindexed_df): allele_tensor = torch.empty( (self.n_reps, self.n_condits, self.n_guides, self.n_max_alleles), ) - if not self.device is None: + if self.device is not None: allele_tensor = allele_tensor.cuda() for i in range(self.n_reps): for j in range(self.n_condits): @@ -710,7 +726,7 @@ def transform_allele(self, adata, reindexed_df): try: assert (allele_tensor >= 0).all(), allele_tensor[allele_tensor < 0] - except: + except AssertionError: print("Allele tensor doesn't match condit_allele_df") return (allele_tensor, reindexed_df) return allele_tensor @@ -724,7 +740,7 @@ def transform_allele_control(self, adata, reindexed_df): allele_tensor = torch.empty( (self.n_reps, 1, self.n_guides, self.n_max_alleles), ) - if not self.device is None: + if self.device is not None: allele_tensor = allele_tensor.cuda() for i in range(self.n_reps): condit_idx = np.where(adata.samples.rep_id == i)[0] @@ -761,14 +777,10 @@ def transform_allele_control(self, adata, reindexed_df): self.n_guides, self.n_max_alleles, ) - try: - 
allele_tensor[i, 0, :, :] = torch.as_tensor(condit_allele_df.to_numpy()) - except: - print("Allele tensor doesn't match condit_allele_df") - return (allele_tensor, torch.as_tensor(condit_allele_df.to_numpy())) + allele_tensor[i, 0, :, :] = torch.as_tensor(condit_allele_df.to_numpy()) try: assert (allele_tensor >= 0).all(), allele_tensor[allele_tensor < 0] - except: + except AssertionError: print("Negative values in allele_tensor") return (allele_tensor, reindexed_df) return allele_tensor @@ -796,48 +808,6 @@ def get_allele_mask( mask[i, j + 1] = 1 return mask.bool() - def get_allele_to_edit_tensor( - self, - adata, - edits_to_index: Dict[str, int], - guide_allele_id_to_allele_df: pd.DataFrame, - ): - """ - Convert (guide, allele_id_for_guide) -> allele DataFrame into the tensor with shape (n_guides, n_max_alleles_per_guide, n_edits) tensor. - ----- - Arguments - edits_to_index (dict) -- Dictionary from edit (str) to unique index (int) - guide_allele_id_to_allele_df (pd.DataFrame) -- pd.DataFrame of (guide(str), allele_id_for_guide(int)) -> CodingNoncodingAllele - ----- - Returns - allele_edit_assignment (torch.Tensor) -- Binary tensor of shape (n_guides, n_max_alleles_per_guide, n_edits). - allele_edit_assignment(i, j, k) is 1 if jth allele of ith guide has kth edit. - """ - guide_allele_id_to_allele_df[ - "edits" - ] = guide_allele_id_to_allele_df.aa_allele.map( - lambda a: list(a.aa_allele.edits) + list(a.nt_allele.edits) - ) - guide_allele_id_to_allele_df = guide_allele_id_to_allele_df.reset_index() - guide_allele_id_to_allele_df[ - "edit_idx" - ] = guide_allele_id_to_allele_df.edits.map( - lambda es: [edits_to_index[e.get_abs_edit()] for e in es] - ) - guide_allele_id_to_edit_df = guide_allele_id_to_allele_df[ - ["guide", "allele_id_for_guide", "edit_idx"] - ].set_index(["guide", "allele_id_for_guide"]) - guide_allele_id_to_edit_df = guide_allele_id_to_edit_df.unstack( - level=1, fill_value=[] - ).reindex(adata.guides.index, fill_value=[]) - allele_edit_assignment = torch.zeros( - (len(adata.guides), self.n_max_alleles - 1, len(edits_to_index.keys())) - ) - for i in range(len(guide_allele_id_to_edit_df)): - for j in range(len(guide_allele_id_to_edit_df.columns)): - allele_edit_assignment[i, j, guide_allele_id_to_edit_df.iloc[i, j]] = 1 - return allele_edit_assignment - @dataclass class SortingScreenData(ScreenData): @@ -909,7 +879,7 @@ def _pre_init( == len(self.screen.samples[self.replicate_column].unique()) ).all(): raise ValueError( - "Not all replicate share same quantile bin definition. If you have missing bin data, add the sample and add 'mask' column in 'screen.samples'." + "Not all replicate share same quantile bin definition. If you have missing bin data, add the sample and add 'mask' column in 'screen.samples' or run `bean-qc` that automatically handles this." 
) sorting_bins = self.screen_selected.samples.sort_values( [upper_quantile_column, lower_quantile_column] @@ -952,12 +922,14 @@ def __init__( repguide_mask: str = None, sample_mask_column: str = None, shrink_alpha: bool = False, - condition_column: str = "time", + condition_column: str = "condition", control_condition: str = "bulk", - accessibility_col: str = None, - accessibility_bw_path: str = None, + control_can_be_selected=True, + time_column: str = "time", + replicate_column: str = "rep", **kwargs, ): + self._pre_init(condition_column) super().__init__( screen=screen, repguide_mask=repguide_mask, @@ -965,14 +937,59 @@ def __init__( shrink_alpha=shrink_alpha, condition_column=condition_column, control_condition=control_condition, - accessibility_col=accessibility_col, - accessibility_bw_path=accessibility_bw_path, + control_can_be_selected=control_can_be_selected, **kwargs, ) + self._post_init() + + def _pre_init(self, time_column: str, condition_column: str): + self.time_column = time_column + if not np.issubdtype(self.screen.samples[time_column].dtype, np.number): + raise ValueError( + f"Invalid timepoint value({self.screen.samples[time_column]}) in screen.samples[{time_column}]: check input." + ) + + if not ( + self.screen.samples.groupby(condition_column).size() + == len(self.screen.samples[self.replicate_column].unique()) + ).all(): + raise ValueError( + f"Not all replicate share same timepoint definition. If you have missing bin data, add the sample and add 'mask' column in 'screen.samples', or run `bean-qc` that automatically handles this. \n{self.screen.samples}" + ) + + def _post_init( + self, + ): self.timepoints = torch.as_tensor( - self.screen.samples[condition_column].unique() + self.screen_selected.samples[self.time_column].unique() + ) + self.n_timepoints = self.n_condits + timepoints = self.screen_selected.samples.sort_values(self.time_column)[ + self.time_column + ].drop_duplicates() + if timepoints.isnull().any(): + raise ValueError( + f"NaN values in time points provided in input: {self.screen_selected.samples[self.time_column]}" + ) + for j, time in enumerate(timepoints): + self.screen_selected.samples.loc[ + self.screen_selected.samples[self.time_column] == time, + f"{self.time_column}_id", + ] = j + self.screen.samples[f"{self.time_column}_id"] = -1 + self.screen.samples.loc[ + self.screen_selected.samples.index, f"{self.time_column}_id" + ] = self.screen_selected.samples[f"{self.time_column}_id"] + self.screen = _assign_rep_ids_and_sort( + self.screen, self.replicate_column, self.time_column + ) + self.screen_selected = _assign_rep_ids_and_sort( + self.screen_selected, self.replicate_column, self.time_column + ) + self.screen_control = _assign_rep_ids_and_sort( + self.screen_control, + self.replicate_column, ) - self.timepoints = Alias("n_condits") @dataclass @@ -985,7 +1002,7 @@ def __init__( pi_popt: Tuple[float] = None, impute_pi_popt: bool = False, shrink_alpha: bool = False, - condition_column: str = "time", + condition_column: str = "condition", control_condition: str = "bulk", accessibility_col: str = None, accessibility_bw_path: str = None, @@ -1036,8 +1053,8 @@ def __init__( screen, *args, sample_mask_column=sample_mask_column, - replicate_column="rep", - condition_column="bin", + replicate_column=replicate_column, + condition_column=condition_column, shrink_alpha=shrink_alpha, **kwargs, ) @@ -1110,8 +1127,8 @@ def __init__( screen, *args, sample_mask_column=sample_mask_column, - replicate_column="rep", - condition_column="bin", + 
replicate_column=replicate_column, + condition_column=condition_column, shrink_alpha=shrink_alpha, **kwargs, ) @@ -1159,8 +1176,8 @@ def __init__( screen, *args, sample_mask_column=sample_mask_column, - replicate_column="rep", - condition_column="bin", + replicate_column=replicate_column, + condition_column=condition_column, shrink_alpha=shrink_alpha, **kwargs, ) @@ -1186,6 +1203,188 @@ def __init__( ) +@dataclass +class VariantSurvivalScreenData(VariantScreenData, SurvivalScreenData): + def __init__( + self, + screen, + *args, + replicate_column="rep", + condition_column="condition", + time_column="time", + control_can_be_selected=True, + target_col="target", + sample_mask_column="mask", + shrink_alpha: bool = False, + use_bcmatch=False, + **kwargs, + ): + ScreenData.__init__( + self, + screen, + *args, + sample_mask_column=sample_mask_column, + replicate_column=replicate_column, + condition_column=condition_column, + time_column=time_column, + shrink_alpha=shrink_alpha, + control_can_be_selected=control_can_be_selected, + **kwargs, + ) + SurvivalScreenData._pre_init(self, time_column, condition_column) + ScreenData._post_init(self) + SurvivalScreenData._post_init(self) + VariantScreenData._post_init(self, target_col) + if use_bcmatch: + self.set_bcmatch( + screen, + ) + + def set_bcmatch(self, screen): + screen.samples["size_factor_bcmatch"] = self.get_size_factor( + screen.layers["X_bcmatch"] + ) + self.screen_selected.samples["size_factor_bcmatch"] = screen.samples.loc[ + self.screen_selected.samples.index, "size_factor_bcmatch" + ] + self.screen_control.samples["size_factor_bcmatch"] = screen.samples.loc[ + self.screen_control.samples.index, "size_factor_bcmatch" + ] + self.X_bcmatch = self.transform_data(self.screen_selected.layers["X_bcmatch"]) + self.X_bcmatch_masked = self.X_bcmatch * self.sample_mask[:, :, None] + self.X_bcmatch_control = self.transform_data( + self.screen_control.layers["X_bcmatch"], 1 + ) + self.X_bcmatch_control_masked = ( + self.X_bcmatch_control * self.bulk_sample_mask[:, None, None] + ) + self.size_factor_bcmatch = torch.as_tensor( + self.screen_selected.samples["size_factor_bcmatch"].to_numpy() + ).reshape(self.n_reps, self.n_condits) + self.size_factor_bcmatch_control = torch.as_tensor( + self.screen_control.samples["size_factor_bcmatch"].to_numpy() + ).reshape(self.n_reps, 1) + a0_bcmatch = get_pred_alpha0( + self.X_bcmatch.clone().cpu(), + self.size_factor_bcmatch.clone().cpu(), + self.popt, + self.sample_mask.cpu(), + ) + self.a0_bcmatch = torch.as_tensor(a0_bcmatch) + + @dataclass class VariantSurvivalReporterScreenData(VariantReporterScreenData, SurvivalScreenData): - pass + def __init__( + self, + screen, + *args, + replicate_column="rep", + condition_column="condition", + time_column="time", + control_can_be_selected=True, + target_col="target", + sample_mask_column="mask", + use_const_pi: bool = False, + impute_pi_popt: bool = False, + pi_prior_count: int = 10, + shrink_alpha: bool = False, + pi_popt: Tuple[float] = None, + **kwargs, + ): + ScreenData.__init__( + self, + screen, + *args, + sample_mask_column=sample_mask_column, + replicate_column=replicate_column, + condition_column=condition_column, + time_column=time_column, + shrink_alpha=shrink_alpha, + control_can_be_selected=control_can_be_selected, + **kwargs, + ) + SurvivalScreenData._pre_init(self, time_column, condition_column) + ScreenData._post_init(self) + SurvivalScreenData._post_init(self) + VariantScreenData._post_init(self, target_col) + ReporterScreenData._post_init( + self, + 
screen, + use_const_pi, + impute_pi_popt, + pi_prior_count, + shrink_alpha, + pi_popt, + ) + + +@dataclass +class TilingSurvivalReporterScreenData(TilingReporterScreenData, SurvivalScreenData): + def __init__( + self, + screen, + *args, + replicate_column="rep", + condition_column="condition", + time_column="time", + control_can_be_selected=True, + sample_mask_column="mask", + use_const_pi: bool = False, + impute_pi_popt: bool = False, + pi_prior_count: int = 10, + shrink_alpha: bool = False, + pi_popt: Tuple[float] = None, + allele_df_key: str = None, + allele_col: str = None, + control_guide_tag: str = None, + **kwargs, + ): + ScreenData.__init__( + self, + screen, + *args, + sample_mask_column=sample_mask_column, + replicate_column=replicate_column, + condition_column=condition_column, + time_column=time_column, + shrink_alpha=shrink_alpha, + control_can_be_selected=control_can_be_selected, + **kwargs, + ) + SurvivalScreenData._pre_init(self, time_column, condition_column) + ScreenData._post_init(self) + SurvivalScreenData._post_init(self) + TilingReporterScreenData._post_init( + self, + allele_df_key=allele_df_key, + control_guide_tag=control_guide_tag, + ) + ReporterScreenData._post_init( + self, + screen, + use_const_pi, + impute_pi_popt, + pi_prior_count, + shrink_alpha, + pi_popt, + ) + + +DATACLASS_DICT = { + "sorting": { + "Normal": VariantSortingScreenData, + "MixtureNormal": VariantSortingReporterScreenData, + "MixtureNormal+Acc": VariantSortingReporterScreenData, + "MixtureNormalConstPi": VariantSortingScreenData, + "MultiMixtureNormal": TilingSortingReporterScreenData, + "MultiMixtureNormal+Acc": TilingSortingReporterScreenData, + }, + "survival": { + "Normal": VariantSurvivalScreenData, + "MixtureNormal": VariantSurvivalReporterScreenData, + "MixtureNormal+Acc": VariantSurvivalReporterScreenData, + "MultiMixtureNormal": TilingSurvivalReporterScreenData, + "MultiMixtureNormal+Acc": TilingSurvivalReporterScreenData, + }, +} diff --git a/bean/preprocessing/utils.py b/bean/preprocessing/utils.py index 1972fa5..ebea9ab 100644 --- a/bean/preprocessing/utils.py +++ b/bean/preprocessing/utils.py @@ -3,8 +3,8 @@ import numpy as np import pyBigWig import pandas as pd -import anndata as ad import bean as be +from bean.qc.guide_qc import filter_no_info_target class Alias: @@ -21,6 +21,34 @@ def __set__(self, obj, value): setattr(obj, self.source_name, value) +def prepare_bdata(bdata: be.ReporterScreen, args, warn, prefix: str): + """Utility function for formatting bdata for bean-run""" + bdata = bdata.copy() + bdata.samples[args.replicate_col] = bdata.samples[args.replicate_col].astype( + "category" + ) + bdata.guides = bdata.guides.loc[:, ~bdata.guides.columns.duplicated()].copy() + if args.library_design == "variant": + if bdata.guides[args.target_col].isnull().any(): + raise ValueError( + f"Some target column (bdata.guides[{args.target_col}]) value is null. Check your input file." + ) + bdata = bdata[bdata.guides[args.target_col].argsort(), :] + n_no_support_targets, bdata = filter_no_info_target( + bdata, + condit_col=args.condition_col, + control_condition=args.control_condition_label, + target_col=args.target_col, + write_no_support_targets=True, + no_support_target_write_path=f"{prefix}/no_support_targets.csv", + ) + if n_no_support_targets > 0: + warn( + f"Ignoring {n_no_support_targets} targets with 0 gRNA counts across all non-control samples. Ignored targets are written in {prefix}/no_support_targets.csv." 
+ ) + return bdata + + def _get_accessibility_single( pos: int, track, diff --git a/bin/bean-run b/bin/bean-run index 1b358ee..6a703a3 100644 --- a/bin/bean-run +++ b/bin/bean-run @@ -3,7 +3,6 @@ import os import sys import logging from copy import deepcopy -from functools import partial import numpy as np import pandas as pd @@ -13,25 +12,22 @@ import pyro.infer import pyro.optim import pickle as pkl -from bean.qc.guide_qc import filter_no_info_target import bean.model.model as m from bean.model.readwrite import write_result_table -from bean.preprocessing.data_class import ( - VariantSortingScreenData, - VariantSortingReporterScreenData, - TilingSortingReporterScreenData, -) +from bean.preprocessing.data_class import DATACLASS_DICT from bean.preprocessing.utils import ( + prepare_bdata, _obtain_effective_edit_rate, _obtain_n_guides_alleles_per_variant, ) import bean as be -from bean.model.utils import ( +from bean.model.run import ( run_inference, _get_guide_target_info, parse_args, check_args, + identify_model_guide, ) logging.basicConfig( @@ -47,73 +43,6 @@ debug = logging.debug info = logging.info pyro.set_rng_seed(101) -DATACLASS_DICT = { - "Normal": VariantSortingScreenData, - "MixtureNormal": VariantSortingReporterScreenData, - "_MixtureNormal+Acc": VariantSortingReporterScreenData, # TODO: old - "MixtureNormal+Acc": VariantSortingReporterScreenData, - "MixtureNormalConstPi": VariantSortingScreenData, - "MultiMixtureNormal": TilingSortingReporterScreenData, - "MultiMixtureNormal+Acc": TilingSortingReporterScreenData, -} - - -def identify_model_guide(args): - if args.mode == "tiling": - info("Using Mixture Normal model...") - return ( - f"MultiMixtureNormal{'+Acc' if args.scale_by_acc else ''}", - partial( - m.MultiMixtureNormalModel, - scale_by_accessibility=args.scale_by_acc, - use_bcmatch=(not args.ignore_bcmatch,), - ), - partial( - m.MultiMixtureNormalGuide, - scale_by_accessibility=args.scale_by_acc, - fit_noise=~args.dont_fit_noise, - ), - ) - if args.uniform_edit: - if args.guide_activity_col is not None: - raise ValueError( - "Can't use the guide activity column while constraining uniform edit." - ) - info("Using Normal model...") - return ( - "Normal", - partial(m.NormalModel, use_bcmatch=(not args.ignore_bcmatch)), - m.NormalGuide, - ) - elif args.const_pi: - if args.guide_activity_col is not None: - raise ValueError( - "--guide-activity-col to be used as constant pi is not provided." - ) - info("Using Mixture Normal model with constant weight ...") - return ( - "MixtureNormalConstPi", - partial(m.MixtureNormalConstPiModel, use_bcmatch=(not args.ignore_bcmatch)), - m.MixtureNormalGuide, - ) - else: - info( - f"Using Mixture Normal model {'with accessibility normalization' if args.scale_by_acc else ''}..." - ) - return ( - f"{'_' if args.dont_fit_noise else ''}MixtureNormal{'+Acc' if args.scale_by_acc else ''}", - partial( - m.MixtureNormalModel, - scale_by_accessibility=args.scale_by_acc, - use_bcmatch=(not args.ignore_bcmatch,), - ), - partial( - m.MixtureNormalGuide, - scale_by_accessibility=args.scale_by_acc, - fit_noise=(not args.dont_fit_noise), - ), - ) - def main(args, bdata): if args.cuda: @@ -128,32 +57,10 @@ def main(args, bdata): ) os.makedirs(prefix, exist_ok=True) model_label, model, guide = identify_model_guide(args) - guide_index = bdata.guides.index info("Done loading data. 
Preprocessing...") - bdata.samples[args.replicate_col] = bdata.samples[args.replicate_col].astype( - "category" - ) - bdata.guides = bdata.guides.loc[:, ~bdata.guides.columns.duplicated()].copy() - if args.mode == "variant": - if bdata.guides[args.target_col].isnull().any(): - raise ValueError( - f"Some target column (bdata.guides[{args.target_col}]) value is null. Check your input file." - ) - bdata = bdata[bdata.guides[args.target_col].argsort(), :] - if args.mode == "variant": - n_no_support_targets, bdata = filter_no_info_target( - bdata, - condit_col=args.condition_col, - control_condition=args.control_condition_label, - target_col=args.target_col, - write_no_support_targets=True, - no_support_target_write_path=f"{prefix}/no_support_targets.csv", - ) - if n_no_support_targets > 0: - warn( - f"Ignoring {n_no_support_targets} targets with 0 gRNA counts across all non-control samples. Ignored targets are written in {prefix}/no_support_targets.csv." - ) - ndata = DATACLASS_DICT[model_label]( + bdata = prepare_bdata(bdata, args, warn, prefix) + guide_index = bdata.guides.index.copy() + ndata = DATACLASS_DICT[args.selection][model_label]( screen=bdata, device=args.device, repguide_mask=args.repguide_mask, @@ -162,7 +69,9 @@ def main(args, bdata): accessibility_bw_path=args.acc_bw_path, use_const_pi=args.const_pi, condition_column=args.condition_col, + time_column=args.time_col, control_condition=args.control_condition_label, + control_can_be_selected=args.include_control_condition_for_inference, allele_df_key=args.allele_df_key, control_guide_tag=args.control_guide_tag, target_col=args.target_col, @@ -171,12 +80,12 @@ def main(args, bdata): use_bcmatch=(not args.ignore_bcmatch), ) adj_negctrl_idx = None - if args.mode == "variant": + if args.library_design == "variant": if not args.uniform_edit: - if "edit_rate" not in bdata.guides.columns: - bdata.get_edit_from_allele() - bdata.get_edit_mat_from_uns(rel_pos_is_reporter=True) - bdata.get_guide_edit_rate() + if "edit_rate" not in ndata.screen.guides.columns: + ndata.screen.get_edit_from_allele() + ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True) + ndata.screen.get_guide_edit_rate() target_info_df = _get_guide_target_info( ndata.screen, args, cols_include=[args.negctrl_col] ) @@ -209,7 +118,9 @@ def main(args, bdata): with open(f"{prefix}/{model_label}.result.pkl", "rb") as handle: param_history_dict = pkl.load(handle) else: - param_history_dict = deepcopy(run_inference(model, guide, ndata)) + param_history_dict = deepcopy( + run_inference(model, guide, ndata, num_steps=args.n_iter) + ) if args.fit_negctrl: negctrl_model = m.ControlNormalModel negctrl_guide = m.ControlNormalGuide @@ -249,6 +160,7 @@ def main(args, bdata): else None, adjust_confidence_by_negative_control=args.adjust_confidence_by_negative_control, adjust_confidence_negatives=adj_negctrl_idx, + sd_is_fitted=(args.selection == "sorting"), ) info("Done!") From 84a03d0b96ffb277489885a06e92ec7e67a9f27c Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Sat, 25 Nov 2023 14:55:28 -0500 Subject: [PATCH 02/13] allow experimental condition sample covaraintes unrelated to selection --- README.md | 2 +- bean/framework/ReporterScreen.py | 18 +++++- bean/preprocessing/data_class.py | 5 +- bean/qc/guide_qc.py | 26 ++++++--- bean/qc/utils.py | 72 +++++++++++++++++------ bin/bean-qc | 1 + notebooks/sample_quality_report.ipynb | 83 +++++++++++++++++++++------ setup.py | 4 +- 8 files changed, 161 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index a18ad4a..5069fa7 
100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ File should contain following columns with header. * `R1_filepath`: Path to read 1 `.fastq[.gz]` file * `R2_filepath`: Path to read 1 `.fastq[.gz]` file * `sample_id`: ID of sequencing sample -* `rep [Optional]`: Replicate # of this sample +* `rep [Optional]`: Replicate # of this sample (Should NOT contain `.`) * `bin [Optional]`: Name of the sorting bin * `upper_quantile [Optional]`: FACS sorting upper quantile * `lower_quantile [Optional]`: FACS sorting lower quantile diff --git a/bean/framework/ReporterScreen.py b/bean/framework/ReporterScreen.py index 05eeb3d..5d012c6 100644 --- a/bean/framework/ReporterScreen.py +++ b/bean/framework/ReporterScreen.py @@ -323,8 +323,20 @@ def __getitem__(self, index): new_uns = deepcopy(self.uns) for k, df in adata.uns.items(): if k.startswith("repguide_mask"): - new_uns[k] = df.loc[guides_include, adata.var.rep.unique()] + if "sample_covariates" in adata.uns: + adata.var["_rc"] = adata.var[ + ["rep"] + adata.uns["sample_covariates"] + ].values.tolist() + adata.var["_rc"] = adata.var["_rc"].map( + lambda slist: ".".join(slist) + ) + new_uns[k] = df.loc[guides_include, adata.var._rc.unique()] + adata.var.pop("_rc") + else: + new_uns[k] = df.loc[guides_include, adata.var.rep.unique()] if not isinstance(df, pd.DataFrame): + if k == "sample_covariates": + new_uns[k] = df continue if "guide" in df.columns: if "allele" in df.columns: @@ -892,7 +904,7 @@ def concat(screens: Collection[ReporterScreen], *args, axis=1, **kwargs): if axis == 0: for k in keys: - if k in ["target_base_change", "tiling"]: + if k in ["target_base_change", "tiling", "sample_covariates"]: adata.uns[k] = screens[0].uns[k] continue elif "edit" not in k and "allele" not in k: @@ -902,7 +914,7 @@ def concat(screens: Collection[ReporterScreen], *args, axis=1, **kwargs): if axis == 1: # If combining multiple samples, edit/allele tables should be merged. for k in keys: - if k in ["target_base_change", "tiling"]: + if k in ["target_base_change", "tiling", "sample_covariates"]: adata.uns[k] = screens[0].uns[k] continue elif "edit" not in k and "allele" not in k: diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index e3c518d..0a16039 100644 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -2,7 +2,7 @@ import abc import logging from dataclasses import dataclass -from typing import Dict, Tuple +from typing import Dict, Tuple, List from xmlrpc.client import Boolean from copy import deepcopy import torch @@ -40,6 +40,7 @@ def __init__( sample_mask_column: str = None, shrink_alpha: bool = False, condition_column: str = "sort", + sample_covariate_column: List[str] = [], control_condition: str = "bulk", accessibility_col: str = None, accessibility_bw_path: str = None, @@ -51,6 +52,7 @@ def __init__( ): """ Args + condition_column: By default, a single condition column, but you can optionally inlcude sample covariate column control_can_be_selected: If True, screen.samples[condition_column] == control_condition can also be included in effect size inference if its condition column is not NA (Currently only suppoted for prolifertion screens). """ # TODO: remove replicate with too small number of (ex. 
only 1) sorting bin @@ -821,6 +823,7 @@ def __init__( sample_mask_column: str = None, shrink_alpha: bool = False, condition_column: str = "sort", + sample_covariate_column: List[str] = [], control_condition: str = "bulk", lower_quantile_column: str = "lower_quantile", upper_quantile_column: str = "upper_quantile", diff --git a/bean/qc/guide_qc.py b/bean/qc/guide_qc.py index 309a775..1bbf379 100644 --- a/bean/qc/guide_qc.py +++ b/bean/qc/guide_qc.py @@ -22,15 +22,27 @@ def get_outlier_guides_and_mask( abs_RPM_thres: RPM threshold value that will be used to define outlier guides. """ outlier_guides = get_outlier_guides(bdata, condit_col, mad_z_thres, abs_RPM_thres) - outlier_guides[replicate_col] = bdata.samples.loc[ - outlier_guides["sample"], replicate_col - ].values - mask = pd.DataFrame( - index=bdata.guides.index, columns=bdata.samples[replicate_col].unique() - ).fillna(1) + if not isinstance(replicate_col, str): + outlier_guides["_rc"] = bdata.samples.loc[ + outlier_guides["sample"], replicate_col + ].values.tolist() + outlier_guides["_rc"] = outlier_guides["_rc"].map(lambda slist: ".".join(slist)) + else: + outlier_guides[replicate_col] = bdata.samples.loc[ + outlier_guides["sample"], replicate_col + ].values + if isinstance(replicate_col, str): + reps = bdata.samples[replicate_col].unique() + else: + reps = bdata.samples[replicate_col].drop_duplicates().to_records(index=False) + reps = [".".join(slist) for slist in reps] + mask = pd.DataFrame(index=bdata.guides.index, columns=reps).fillna(1) print(outlier_guides) for _, row in outlier_guides.iterrows(): - mask.loc[row["name"], row[replicate_col]] = 0 + mask.loc[ + row["name"], row[replicate_col if isinstance(replicate_col, str) else "_rc"] + ] = 0 + return outlier_guides, mask diff --git a/bean/qc/utils.py b/bean/qc/utils.py index 48945d0..f584783 100644 --- a/bean/qc/utils.py +++ b/bean/qc/utils.py @@ -1,4 +1,5 @@ import distutils +from typing import Union, List import numpy as np import pandas as pd from copy import deepcopy @@ -52,12 +53,23 @@ def parse_args(): type=str, default="rep", ) + parser.add_argument( + "--sample-covariates", + help="Comma-separated list of column names in `bdata.samples` that describes non-selective experimental condition. (drug treatment, etc.)", + type=str, + default=None, + ) parser.add_argument( "--condition-label", help="Label of column in `bdata.samples` that describes experimental condition. (sorting bin, time, etc.)", type=str, default="bin", ) + parser.add_argument( + "--no-editing", + help="Ignore QC about editing. 
Can be used for QC of other editing modalities.", + action="store_true", + ) parser.add_argument( "--target-pos-col", help="Target position column in `bdata.guides` specifying target edit position in reporter", @@ -143,21 +155,30 @@ def check_args(args): ) args.lfc_cond1 = lfc_conds[0] args.lfc_cond2 = lfc_conds[1] + if args.sample_covariates is not None: + if "," in args.sample_covariates: + args.sample_covariates = args.sample_covariates.split(",") + args.replicate_label = [args.replicate_label] + args.sample_covariates + else: + args.replicate_label = [args.replicate_label, args.sample_covariates] + if args.no_editing: + args.base_edit_data = False + else: + args.base_edit_data = True return args -def _add_dummy_sample(bdata, rep, cond, condition_label: str, replicate_label: str): +def _add_dummy_sample( + bdata, rep, cond, condition_label: str, replicate_label: Union[str, List[str]] +): sample_id = f"{rep}_{cond}" cond_df = deepcopy(bdata.samples) - cond_df[replicate_label] = np.nan - cond_df = cond_df.drop_duplicates() + # cond_df = cond_df.drop_duplicates() cond_row = cond_df.loc[cond_df[condition_label] == cond, :] - if not len(cond_row) == 1: - raise ValueError( - f"Non-unique condition specification in ReporterScreen.samples: {cond_row}" - ) + if len(cond_row) != 1: + cond_row = cond_row.iloc[[0], :] cond_row.index = [sample_id] - cond_row.loc[:, replicate_label] = rep + cond_row[replicate_label] = rep dummy_sample_bdata = ReporterScreen( X=np.zeros((bdata.n_obs, 1)), X_bcmatch=np.zeros((bdata.n_obs, 1)), @@ -175,27 +196,40 @@ def _add_dummy_sample(bdata, rep, cond, condition_label: str, replicate_label: s return bdata -def fill_in_missing_samples(bdata, condition_label: str, replicate_label: str): +def fill_in_missing_samples( + bdata, condition_label: str, replicate_label: Union[str, List[str]] +): """If not all condition exists for every replicate in bdata, fill in fake sample""" added_dummy = False - for rep in bdata.samples[replicate_label].unique(): + if isinstance(replicate_label, str): + rep_list = bdata.samples[replicate_label].unique() + else: + rep_list = ( + bdata.samples[replicate_label].drop_duplicates().to_records(index=False) + ) + # print(rep_list) + for rep in rep_list: for cond in bdata.samples[condition_label].unique(): + if isinstance(replicate_label, str): + rep_samples = bdata.samples[replicate_label] == rep + else: + rep = list(rep) + rep_samples = (bdata.samples[replicate_label] == rep).all(axis=1) if ( - len( - np.where( - (bdata.samples[replicate_label] == rep) - & (bdata.samples[condition_label] == cond) - )[0] - ) + len(np.where(rep_samples & (bdata.samples[condition_label] == cond))[0]) != 1 ): + print(f"Adding dummy samples for {rep}, {cond}") bdata = _add_dummy_sample( bdata, rep, cond, condition_label, replicate_label ) if not added_dummy: added_dummy = True if added_dummy: - bdata = bdata[ - :, bdata.samples.sort_values([replicate_label, condition_label]).index - ] + if isinstance(replicate_label, str): + sort_labels = [replicate_label, condition_label] + else: + sort_labels = replicate_label + [condition_label] + bdata = bdata[:, bdata.samples.sort_values(sort_labels).index] + return bdata diff --git a/bin/bean-qc b/bin/bean-qc index c3eb72e..c83d6f2 100644 --- a/bin/bean-qc +++ b/bin/bean-qc @@ -34,6 +34,7 @@ def main(): ctrl_cond=args.ctrl_cond, exp_id=args.out_report_prefix, recalculate_edits=args.recalculate_edits, + base_edit_data=args.base_edit_data, ), kernel_name="bean_python3", ) diff --git a/notebooks/sample_quality_report.ipynb 
b/notebooks/sample_quality_report.ipynb index 70b7c44..205d3a3 100644 --- a/notebooks/sample_quality_report.ipynb +++ b/notebooks/sample_quality_report.ipynb @@ -56,7 +56,8 @@ "comp_cond2 = \"bot\"\n", "ctrl_cond = \"bulk\"\n", "recalculate_edits = False\n", - "tiling = None" + "tiling = None\n", + "base_edit_data = True" ] }, { @@ -75,7 +76,9 @@ "outputs": [], "source": [ "if tiling is not None:\n", - " bdata.uns['tiling'] = tiling" + " bdata.uns['tiling'] = tiling\n", + "if not isinstance(replicate_label, str):\n", + " bdata.uns['sample_covariates'] = replicate_label[1:]" ] }, { @@ -208,12 +211,7 @@ "outputs": [], "source": [ "selected_guides = bdata.guides[posctrl_col] == posctrl_val if posctrl_col else ~bdata.guides.index.isnull()\n", - "ax=pt.qc.plot_lfc_correlation(bdata, selected_guides, method=\"Spearman\", cond1=comp_cond1, cond2=comp_cond2, rep_col=replicate_label, compare_col=condition_label, figsize=(10,10))\n", - "\n", - "ax.set_title(\"top/bot LFC correlation, Spearman\")\n", - "plt.yticks(rotation=0) \n", - "plt.xticks(rotation=90) \n", - "plt.show()" + "print(f\"Calculating LFC correlation of {sum(selected_guides)} {'positive control' if posctrl_col else 'all'} guides.\")" ] }, { @@ -221,7 +219,23 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "ax = pt.qc.plot_lfc_correlation(\n", + " bdata,\n", + " selected_guides,\n", + " method=\"Spearman\",\n", + " cond1=comp_cond1,\n", + " cond2=comp_cond2,\n", + " rep_col=replicate_label,\n", + " compare_col=condition_label,\n", + " figsize=(10, 10),\n", + ")\n", + "\n", + "ax.set_title(\"top/bot LFC correlation, Spearman\")\n", + "plt.yticks(rotation=0)\n", + "plt.xticks(rotation=90)\n", + "plt.show()" + ] }, { "cell_type": "code", @@ -243,7 +257,12 @@ "metadata": {}, "outputs": [], "source": [ - "if recalculate_edits or \"edits\" not in bdata.layers.keys() or bdata.layers['edits'].max() == 0:\n", + "if \"target_base_change\" not in bdata.uns or not base_edit_data:\n", + " bdata.uns[\"target_base_change\"] = \"\"\n", + " base_edit_data = False\n", + " print(\"Not a base editing data or target base change not provided. Passing editing-related QC\")\n", + " edit_rate_threshold = -0.1\n", + "elif recalculate_edits or \"edits\" not in bdata.layers.keys() or bdata.layers['edits'].max() == 0:\n", " if 'allele_counts' in bdata.uns.keys():\n", " bdata.uns['allele_counts'] = bdata.uns['allele_counts'].loc[bdata.uns['allele_counts'].allele.map(str) != \"\"]\n", " bdata.get_edit_from_allele()\n", @@ -261,12 +280,22 @@ "metadata": {}, "outputs": [], "source": [ - "if \"edits\" in bdata.layers.keys():\n", + "if \"target_base_change\" not in bdata.uns or not base_edit_data:\n", + " print(\n", + " \"Not a base editing data or target base change not provided. Passing editing-related QC\"\n", + " )\n", + "elif \"edits\" in bdata.layers.keys():\n", + "\n", " bdata.get_guide_edit_rate(\n", + "\n", " editable_base_start=edit_quantification_start_pos,\n", + "\n", " editable_base_end=edit_quantification_end_pos,\n", + "\n", " unsorted_condition_label=ctrl_cond,\n", + "\n", " )\n", + "\n", " be.qc.plot_guide_edit_rates(bdata)" ] }, @@ -276,11 +305,19 @@ "metadata": {}, "outputs": [], "source": [ - "if \"edits\" in bdata.layers.keys():\n", + "if \"target_base_change\" not in bdata.uns or not base_edit_data:\n", + " print(\n", + " \"Not a base editing data or target base change not provided. 
Passing editing-related QC\"\n", + " )\n", + "elif \"edits\" in bdata.layers.keys():\n", + "\n", " bdata.get_edit_rate(\n", - " editable_base_start = edit_quantification_start_pos, \n", - " editable_base_end=edit_quantification_end_pos\n", + " editable_base_start=edit_quantification_start_pos,\n", + "\n", + " editable_base_end=edit_quantification_end_pos,\n", + "\n", " )\n", + "\n", " be.qc.plot_sample_edit_rates(bdata)" ] }, @@ -329,12 +366,24 @@ "outputs": [], "source": [ "# leave replicate with more than 1 sorting bin data\n", - "rep_n_samples = bdata_filtered.samples.groupby(replicate_label)['mask'].sum()\n", + "rep_n_samples = bdata_filtered.samples.groupby(replicate_label)[\"mask\"].sum()\n", "print(rep_n_samples)\n", "rep_has_too_small_sample = rep_n_samples.loc[rep_n_samples < 2].index.tolist()\n", "rep_has_too_small_sample\n", - "print(f\"Excluding reps {rep_has_too_small_sample} that has less than 2 samples per replicate.\")\n", - "bdata_filtered = bdata_filtered[:, ~bdata_filtered.samples[replicate_label].isin(rep_has_too_small_sample)]" + "print(\n", + " f\"Excluding reps {rep_has_too_small_sample} that has less than 2 samples per replicate.\"\n", + ")\n", + "if isinstance(replicate_label, str):\n", + " samples_include = ~bdata_filtered.samples[replicate_label].isin(\n", + " rep_has_too_small_sample\n", + " )\n", + "else:\n", + " bdata_filtered.samples[\"_rc\"] = bdata_filtered.samples[\n", + " replicate_label\n", + " ].values.tolist()\n", + " samples_include = ~bdata_filtered.samples[\"_rc\"].isin(rep_has_too_small_sample)\n", + "bdata_filtered = bdata_filtered[:, samples_include]\n", + "bdata_filtered.samples.pop(\"_rc\")" ] }, { diff --git a/setup.py b/setup.py index 4b2157f..f41178c 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="crispr-bean", - version="0.2.9", + version="0.3.0", python_requires=">=3.8.0", author="Jayoung Ryu", author_email="jayoung_ryu@g.harvard.edu", @@ -36,7 +36,7 @@ "numpy", "pandas", "scipy", - "perturb-tools>=0.2.8", + "perturb-tools>=0.3.0", "matplotlib", "seaborn>=0.13.0", "tqdm", From 5656d92dfa18571846cc782ef945be39cf37090a Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Sat, 25 Nov 2023 14:58:33 -0500 Subject: [PATCH 03/13] allow condition & non-BE data --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 5069fa7..2504b2e 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,8 @@ Above command produces * `--tiling` (default: `None`): If set as `True` or `False`, it sets the screen object to be tiling (`True`) or variant (`False`)-targeting screen when calculating editing rate. * `--replicate-label` (default: `"rep"`): Label of column in `bdata.samples` that describes replicate ID. * `--condition-label` (default: `"bin"`)": Label of column in `bdata.samples` that describes experimental condition. (sorting bin, time, etc.). +* `--sample-covariates` (default: `None`): Comma-separated list of column names in `bdata.samples` that describes non-selective experimental condition (drug treatment, etc.). The values in the `bdata.samples` should NOT contain `.`. +* `--no-editing` (default: `False`): Ignore QC about editing. Can be used for QC of other editing modalities. * `--target-pos-col` (default: `"target_pos"`): Target position column in `bdata.guides` specifying target edit position in reporter. * `--rel-pos-is-reporter` (default: `False`): Specifies whether `edit_start_pos` and `edit_end_pos` are relative to reporter position. If `False`, those are relative to spacer position. 
* `--edit-start-pos` (default: `2`): Edit start position to quantify editing rate on, 0-based inclusive. From ee96e71b0a34cb1236479a0406e82f0aad86dd71 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Sat, 25 Nov 2023 17:51:15 -0500 Subject: [PATCH 04/13] Initial run with condition --- bean/framework/ReporterScreen.py | 6 ++- bean/model/model.py | 73 +++++++++++++++++++-------- bean/model/utils.py | 28 ++++++---- bean/preprocessing/data_class.py | 43 ++++++++++++++-- bean/preprocessing/utils.py | 8 +-- bin/bean-run | 20 ++++++-- notebooks/sample_quality_report.ipynb | 17 ++++--- 7 files changed, 144 insertions(+), 51 deletions(-) diff --git a/bean/framework/ReporterScreen.py b/bean/framework/ReporterScreen.py index 5d012c6..37631d5 100644 --- a/bean/framework/ReporterScreen.py +++ b/bean/framework/ReporterScreen.py @@ -99,6 +99,8 @@ def __init__( self.layers["X_bcmatch"] = X_bcmatch for k, df in self.uns.items(): if not isinstance(df, pd.DataFrame): + if k == "sample_covariates" and not isinstance(df, list): + self.uns[k] = df.tolist() continue if "guide" in df.columns and len(df) > 0: if ( @@ -325,13 +327,13 @@ def __getitem__(self, index): if k.startswith("repguide_mask"): if "sample_covariates" in adata.uns: adata.var["_rc"] = adata.var[ - ["rep"] + adata.uns["sample_covariates"] + ["rep"] + list(adata.uns["sample_covariates"]) ].values.tolist() adata.var["_rc"] = adata.var["_rc"].map( lambda slist: ".".join(slist) ) new_uns[k] = df.loc[guides_include, adata.var._rc.unique()] - adata.var.pop("_rc") + #adata.var.pop("_rc") else: new_uns[k] = df.loc[guides_include, adata.var.rep.unique()] if not isinstance(df, pd.DataFrame): diff --git a/bean/model/model.py b/bean/model/model.py index c6be6e9..91514f1 100644 --- a/bean/model/model.py +++ b/bean/model/model.py @@ -45,26 +45,51 @@ def NormalModel( sd = sd_alleles sd = torch.repeat_interleave(sd, data.target_lengths, dim=0) assert sd.shape == (data.n_guides, 1) - + if data.sample_covariates is not None: + with pyro.plate("cov_place", data.n_sample_covariates): + mu_cov = pyro.sample("mu_cov", dist.Normal(0, 1)) + assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape with replicate_plate: with bin_plate as b: uq = data.upper_bounds[b] lq = data.lower_bounds[b] assert uq.shape == lq.shape == (data.n_condits,) - # with guide_plate, poutine.mask(mask=(data.allele_counts.sum(axis=-1) == 0)): with guide_plate: + mu = mu.unsqueeze(0).unsqueeze(0).expand( + (data.n_reps, data.n_condits, -1, -1) + ) + (data.rep_by_cov * mu_cov)[:, 0].unsqueeze(-1).unsqueeze( + -1 + ).unsqueeze( + -1 + ).expand( + (-1, data.n_condits, data.n_guides, 1) + ) + sd = torch.sqrt( + ( + sd.unsqueeze(0) + .unsqueeze(0) + .expand((data.n_reps, data.n_condits, -1, -1)) + ) + ) alleles_p_bin = get_std_normal_prob( - uq.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)), - lq.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)), - mu.unsqueeze(0).expand((data.n_condits, -1, -1)), - sd.unsqueeze(0).expand((data.n_condits, -1, -1)), + uq.unsqueeze(0) + .unsqueeze(-1) + .unsqueeze(-1) + .expand((data.n_reps, -1, data.n_guides, 1)), + lq.unsqueeze(0) + .unsqueeze(-1) + .unsqueeze(-1) + .expand((data.n_reps, -1, data.n_guides, 1)), + mu, + sd, ) - assert alleles_p_bin.shape == (data.n_condits, data.n_guides, 1) - - expected_allele_p = alleles_p_bin.unsqueeze(0).expand( - data.n_reps, -1, -1, -1 - ) - expected_guide_p = expected_allele_p.sum(axis=-1) + assert alleles_p_bin.shape == ( + data.n_reps, + data.n_condits, + data.n_guides, + 1, + ) + expected_guide_p 
= alleles_p_bin.sum(axis=-1) assert expected_guide_p.shape == ( data.n_reps, data.n_condits, @@ -158,14 +183,10 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True): with pyro.plate("guide_plate3", data.n_guides, dim=-1): a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0) - assert ( - data.X.shape - == data.X_bcmatch.shape - == ( - data.n_reps, - data.n_condits, - data.n_guides, - ) + assert data.X.shape == ( + data.n_reps, + data.n_condits, + data.n_guides, ) with poutine.mask( mask=torch.logical_and( @@ -490,6 +511,18 @@ def NormalGuide(data): constraint=constraints.positive, ) pyro.sample("sd_alleles", dist.LogNormal(sd_loc, sd_scale)) + if data.sample_covariates is not None: + with pyro.plate("cov_place", data.n_sample_covariates): + mu_cov_loc = pyro.param( + "mu_cov_loc", torch.zeros((data.n_sample_covariates,)) + ) + mu_cov_scale = pyro.param( + "mu_cov_scale", + torch.ones((data.n_sample_covariates,)), + constraint=constraints.positive, + ) + mu_cov = pyro.sample("mu_cov", dist.Normal(mu_cov_loc, mu_cov_scale)) + assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape def MixtureNormalGuide( diff --git a/bean/model/utils.py b/bean/model/utils.py index 7767232..7dfba0a 100644 --- a/bean/model/utils.py +++ b/bean/model/utils.py @@ -8,17 +8,23 @@ def get_alpha( expected_guide_p, size_factor, sample_mask, a0, epsilon=1e-5, normalize_by_a0=True ): - p = ( - expected_guide_p.permute(0, 2, 1) * size_factor[:, None, :] - ) # (n_reps, n_guides, n_bins) - if normalize_by_a0: - a = ( - (p + epsilon / p.shape[-1]) - / (p.sum(axis=-1)[:, :, None] + epsilon) - * a0[None, :, None] - ) - a = (a * sample_mask[:, None, :]).clamp(min=epsilon) - return a + try: + p = ( + expected_guide_p.permute(0, 2, 1) * size_factor[:, None, :] + ) # (n_reps, n_guides, n_bins) + + if normalize_by_a0: + a = ( + (p + epsilon / p.shape[-1]) + / (p.sum(axis=-1)[:, :, None] + epsilon) + * a0[None, :, None] + ) + a = (a * sample_mask[:, None, :]).clamp(min=epsilon) + return a + except: + print(size_factor.shape) + print(expected_guide_p.shape) + print(a0.shape) a = (p * sample_mask[:, None, :]).clamp(min=epsilon) return a diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 0a16039..9f01947 100644 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -60,17 +60,34 @@ def __init__( self.device = device screen.samples["size_factor"] = self.get_size_factor(screen.X) if not ( - "rep" in screen.samples.columns + replicate_column in screen.samples.columns and condition_column in screen.samples.columns ): - screen.samples["rep"], screen.samples[condition_column] = zip( + screen.samples[replicate_column], screen.samples[condition_column] = zip( *screen.samples.index.map(lambda s: s.rsplit("_", 1)) ) if condition_column not in screen.samples.columns: screen.samples[condition_column] = screen.samples["index"].map( lambda s: s.split("_")[-1] ) - + if "sample_covariates" in screen.uns: + self.sample_covariates = screen.uns["sample_covariates"] + self.n_sample_covariates = len(self.sample_covariates) + screen.samples["_rc"] = screen.samples[ + [replicate_column] + self.sample_covariates + ].values.tolist() + screen.samples["_rc"] = screen.samples["_rc"].map( + lambda slist: ".".join(slist) + ) + self.rep_by_cov = torch.as_tensor( + ( + screen.samples[["_rc"] + self.sample_covariates] + .drop_duplicates() + .set_index("_rc") + .values.astype(int) + ) + ) + replicate_column = "_rc" self.screen = screen if not 
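The `mu_cov` terms added in this patch give each sample covariate its own additive shift on the variant effect: the model draws the shift from a standard normal prior, and the guide fits a per-covariate location and scale. A stripped-down sketch of that prior/guide pair (the plate name mirrors the diff; `n_cov = 2` and everything else is illustrative):

```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

n_cov = 2  # illustrative number of sample covariates

def model():
    # Prior: one additive effect shift per covariate level.
    with pyro.plate("cov_place", n_cov):
        return pyro.sample("mu_cov", dist.Normal(0.0, 1.0))

def guide():
    # Variational posterior: free location/scale per covariate.
    loc = pyro.param("mu_cov_loc", torch.zeros(n_cov))
    scale = pyro.param(
        "mu_cov_scale", torch.ones(n_cov), constraint=constraints.positive
    )
    with pyro.plate("cov_place", n_cov):
        return pyro.sample("mu_cov", dist.Normal(loc, scale))
```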
control_can_be_selected: self.screen_selected = screen[ @@ -146,7 +163,7 @@ def _post_init( ).all() assert ( self.screen_selected.uns[self.repguide_mask].columns - == self.screen_selected.samples.rep.unique() + == self.screen_selected.samples[self.replicate_column].unique() ).all() self.repguide_mask = ( torch.as_tensor(self.screen_selected.uns[self.repguide_mask].values.T) @@ -182,6 +199,7 @@ def __getitem__(self, guide_idx): ndata.X_masked = ndata.X_masked[:, :, guide_idx] ndata.X_control = ndata.X_control[:, :, guide_idx] ndata.repguide_mask = ndata.repguide_mask[:, guide_idx] + ndata.a0 = ndata.a0[guide_idx] return ndata def transform_data(self, X, n_bins=None): @@ -905,9 +923,20 @@ def _pre_init( self.screen.samples.loc[ self.screen_selected.samples.index, f"{self.condition_column}_id" ] = self.screen_selected.samples[f"{self.condition_column}_id"] + print(self.screen.samples.columns) self.screen = _assign_rep_ids_and_sort( self.screen, self.replicate_column, self.condition_column ) + print(self.screen.samples.columns) + if self.sample_covariates is not None: + self.rep_by_cov = torch.as_tensor( + ( + self.screen.samples[["_rc"] + self.sample_covariates] + .drop_duplicates() + .set_index("_rc") + .values.astype(int) + ) + ) self.screen_selected = _assign_rep_ids_and_sort( self.screen_selected, self.replicate_column, self.condition_column ) @@ -986,8 +1015,12 @@ def _post_init( self.screen = _assign_rep_ids_and_sort( self.screen, self.replicate_column, self.time_column ) + if self.sample_covariates is not None: + self.rep_by_cov = self.screen.samples.groupby(self.replicate_column)[ + self.sample_covariates + ].values self.screen_selected = _assign_rep_ids_and_sort( - self.screen_selected, self.replicate_column, self.time_column + self.screen_selected, self.replicate_column, self.condition_column ) self.screen_control = _assign_rep_ids_and_sort( self.screen_control, diff --git a/bean/preprocessing/utils.py b/bean/preprocessing/utils.py index 8596782..44c6431 100644 --- a/bean/preprocessing/utils.py +++ b/bean/preprocessing/utils.py @@ -219,10 +219,10 @@ def _assign_rep_ids_and_sort( sort_key = f"{rep_col}_id" else: sort_key = [f"{rep_col}_id", f"{condition_column}_id"] - screen = screen[ - :, - screen.samples.sort_values(sort_key).index, - ] + screen = screen[ + :, + screen.samples.sort_values(sort_key).index, + ] return screen diff --git a/bin/bean-run b/bin/bean-run index 669da0b..04d364e 100644 --- a/bin/bean-run +++ b/bin/bean-run @@ -2,6 +2,8 @@ import os import sys import logging +import warnings +from functools import partial from copy import deepcopy import numpy as np import pandas as pd @@ -43,6 +45,11 @@ warn = logging.warning debug = logging.debug info = logging.info pyro.set_rng_seed(101) +warnings.filterwarnings( + "ignore", + category=FutureWarning, + message=r".*is_categorical_dtype is deprecated and will be removed in a future version.*", +) def main(args, bdata): @@ -127,8 +134,15 @@ def main(args, bdata): run_inference(model, guide, ndata, num_steps=args.n_iter) ) if args.fit_negctrl: - negctrl_model = m.ControlNormalModel - negctrl_guide = m.ControlNormalGuide + negctrl_model = partial( + m.ControlNormalModel, + use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers), + ) + print((not args.ignore_bcmatch and "X_bcmatch" in bdata.layers)) + negctrl_guide = partial( + m.ControlNormalGuide, + use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers), + ) negctrl_idx = np.where( guide_info_df[args.negctrl_col].map(lambda s: s.lower()) == 
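The `functools.partial` calls above bake `use_bcmatch` into the model and guide before handing them to the inference loop, which only ever calls them with the data object. A self-contained sketch of that pattern (the stand-in function below is hypothetical, not the package's `ControlNormalModel`):

```python
from functools import partial

def control_model(data, mask_thres=10, use_bcmatch=True):
    # Stand-in for m.ControlNormalModel: just report the bound options.
    return f"fit {data!r} with mask_thres={mask_thres}, use_bcmatch={use_bcmatch}"

# Bind the keyword once; the partial keeps the single-argument call
# signature that the inference loop expects for a model callable.
negctrl_model = partial(control_model, use_bcmatch=False)
print(negctrl_model("negctrl_data"))
```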
args.negctrl_col_value.lower() @@ -137,7 +151,7 @@ def main(args, bdata): print(negctrl_idx.shape) ndata_negctrl = ndata[negctrl_idx] param_history_dict["negctrl"] = run_inference( - negctrl_model, negctrl_guide, ndata_negctrl + negctrl_model, negctrl_guide, ndata_negctrl, num_steps=args.n_iter ) outfile_path = ( diff --git a/notebooks/sample_quality_report.ipynb b/notebooks/sample_quality_report.ipynb index 205d3a3..bf632b3 100644 --- a/notebooks/sample_quality_report.ipynb +++ b/notebooks/sample_quality_report.ipynb @@ -76,9 +76,10 @@ "outputs": [], "source": [ "if tiling is not None:\n", - " bdata.uns['tiling'] = tiling\n", + " bdata.uns[\"tiling\"] = tiling\n", "if not isinstance(replicate_label, str):\n", - " bdata.uns['sample_covariates'] = replicate_label[1:]" + " bdata.uns[\"sample_covariates\"] = replicate_label[1:]\n", + "bdata.samples[replicate_label] = bdata.samples[replicate_label].astype(str)" ] }, { @@ -352,11 +353,15 @@ "metadata": {}, "outputs": [], "source": [ - "bdata.samples['mask'] = 1\n", - "bdata.samples.loc[bdata.samples.median_corr_X < corr_X_thres, 'mask'] = 0\n", + "bdata.samples[\"mask\"] = 1\n", + "bdata.samples.loc[\n", + " bdata.samples.median_corr_X.isnull() | (bdata.samples.median_corr_X < corr_X_thres), \"mask\"\n", + "] = 0\n", "if \"median_editing_rate\" in bdata.samples.columns.tolist():\n", - " bdata.samples.loc[bdata.samples.median_editing_rate < edit_rate_thres, 'mask'] = 0\n", - "bdata_filtered = bdata[:, bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] > lfc_thres]" + " bdata.samples.loc[bdata.samples.median_editing_rate < edit_rate_thres, \"mask\"] = 0\n", + "bdata_filtered = bdata[\n", + " :, bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] > lfc_thres\n", + "]" ] }, { From 6b1682a85aded0a66585c6c5077bae26c5f7ca79 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Sat, 25 Nov 2023 18:17:23 -0500 Subject: [PATCH 05/13] debug condition runs --- bean/model/readwrite.py | 48 +++++++++++++++++++++++++++++++- bean/preprocessing/data_class.py | 2 -- bin/bean-run | 8 ++++-- 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/bean/model/readwrite.py b/bean/model/readwrite.py index b41de39..c1c68cd 100644 --- a/bean/model/readwrite.py +++ b/bean/model/readwrite.py @@ -1,4 +1,4 @@ -from typing import Union, Sequence, Optional +from typing import Union, Sequence, Optional, List import numpy as np import pandas as pd from statistics import NormalDist @@ -58,6 +58,7 @@ def write_result_table( guide_index: Optional[Sequence[str]] = None, guide_acc: Optional[Sequence] = None, sd_is_fitted: bool = True, + sample_covariates: List[str] = None, return_result: bool = False, ) -> Union[pd.DataFrame, None]: """Combine target information and scores to write result table to a csv file or return it.""" @@ -82,6 +83,24 @@ def write_result_table( } if sd_is_fitted: param_dict["sd"] = sd + if sample_covariates is not None: + assert ( + "mu_cov_loc" in param_hist_dict["params"] + and "mu_cov_scale" in param_hist_dict["params"] + ), param_hist_dict["params"].keys() + for i, sample_cov in enumerate(sample_covariates): + param_dict[f"mu_{sample_cov}"] = ( + mu + param_hist_dict["params"]["mu_cov_loc"].detach().cpu().numpy()[i] + ) + param_dict[f"mu_sd_{sample_cov}"] = np.sqrt( + mu_sd**2 + + param_hist_dict["params"]["mu_cov_scale"].detach().cpu().numpy()[i] + ** 2 + ) + param_dict[f"mu_z_{sample_cov}"] = ( + param_dict[f"mu_{sample_cov}"] / param_dict[f"mu_sd_{sample_cov}"] + ) + fit_df = pd.DataFrame(param_dict) fit_df["novl"] = 
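The per-covariate result columns computed above follow the usual rule for sums of independent Gaussians: the means add, and the standard deviations combine in quadrature since Var(X + Y) = Var(X) + Var(Y). A worked numeric check with made-up posterior values:

```python
import numpy as np

mu, mu_sd = 0.8, 0.2            # shared effect: posterior mean and sd
cov_loc, cov_scale = 0.3, 0.1   # covariate offset: posterior mean and sd

mu_cov = mu + cov_loc                         # 1.1
mu_sd_cov = np.sqrt(mu_sd**2 + cov_scale**2)  # sqrt(0.05) ~= 0.224
z = mu_cov / mu_sd_cov                        # ~= 4.92
print(mu_cov, mu_sd_cov, z)
```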
get_novl(fit_df, "mu", "mu_sd") if "negctrl" in param_hist_dict.keys(): @@ -102,6 +121,17 @@ def write_result_table( if sd_is_fitted: fit_df["sd_scaled"] = sd / sd0 fit_df["novl_scaled"] = get_novl(fit_df, "mu_scaled", "mu_sd_scaled") + if sample_covariates is not None: + for i, sample_cov in enumerate(sample_covariates): + fit_df[f"mu_{sample_cov}_scaled"] = ( + fit_df[f"mu_{sample_cov}"] - mu0 + ) / sd0 + fit_df[f"mu_sd_{sample_cov}_scaled"] = ( + fit_df[f"mu_sd_{sample_cov}"] / sd0 + ) + fit_df[f"mu_z_{sample_cov}_scaled"] = ( + fit_df[f"mu_{sample_cov}_scaled"] / fit_df["mu_sd_scaled"] + ) fit_df = pd.concat( [target_info_df.reset_index(), fit_df.reset_index(drop=True)], axis=1 @@ -132,6 +162,22 @@ def write_result_table( else "mu_sd", ) fit_df = add_credible_interval(fit_df, "mu_adj", "mu_sd_adj") + if sample_covariates is not None: + for i, sample_cov in enumerate(sample_covariates): + fit_df = adjust_normal_params_by_control( + fit_df, + std, + suffix=f"_{sample_cov}_adj", + mu_adjusted_col=f"mu_{sample_cov}_scaled" + if "negctrl" in param_hist_dict.keys() + else f"mu_{sample_cov}", + mu_sd_adjusted_col=f"mu_sd_{sample_cov}_scaled" + if "negctrl" in param_hist_dict.keys() + else f"mu_sd_{sample_cov}", + ) + fit_df = add_credible_interval( + fit_df, f"mu_{sample_cov}_adj", f"mu_sd_{sample_cov}_adj" + ) if write_fitted_eff or guide_acc is not None: if "alpha_pi" not in param_hist_dict["params"].keys(): diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 9f01947..505dab7 100644 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -923,11 +923,9 @@ def _pre_init( self.screen.samples.loc[ self.screen_selected.samples.index, f"{self.condition_column}_id" ] = self.screen_selected.samples[f"{self.condition_column}_id"] - print(self.screen.samples.columns) self.screen = _assign_rep_ids_and_sort( self.screen, self.replicate_column, self.condition_column ) - print(self.screen.samples.columns) if self.sample_covariates is not None: self.rep_by_cov = torch.as_tensor( ( diff --git a/bin/bean-run b/bin/bean-run index 04d364e..6bff913 100644 --- a/bin/bean-run +++ b/bin/bean-run @@ -138,7 +138,7 @@ def main(args, bdata): m.ControlNormalModel, use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers), ) - print((not args.ignore_bcmatch and "X_bcmatch" in bdata.layers)) + negctrl_guide = partial( m.ControlNormalGuide, use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers), @@ -147,8 +147,9 @@ def main(args, bdata): guide_info_df[args.negctrl_col].map(lambda s: s.lower()) == args.negctrl_col_value.lower() )[0] - print(len(negctrl_idx)) - print(negctrl_idx.shape) + info( + f"Using {len(negctrl_idx)} negative control elements to adjust phenotypic effect sizes..." 
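Scaling with the negative-control fit, as done above with `mu0` and `sd0`, re-expresses every effect size in units of the negative-control spread so that true negatives center near zero. A toy example with invented numbers:

```python
import numpy as np

mu = np.array([0.05, 0.90])   # raw posterior means for two variants
mu0, sd0 = 0.04, 0.15         # fitted negative-control mean and sd

mu_scaled = (mu - mu0) / sd0
print(mu_scaled)  # [0.0667, 5.733]: ~0 sd vs ~5.7 sd above negctrl
```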
+ ) ndata_negctrl = ndata[negctrl_idx] param_history_dict["negctrl"] = run_inference( negctrl_model, negctrl_guide, ndata_negctrl, num_steps=args.n_iter @@ -180,6 +181,7 @@ def main(args, bdata): adjust_confidence_by_negative_control=args.adjust_confidence_by_negative_control, adjust_confidence_negatives=adj_negctrl_idx, sd_is_fitted=(args.selection == "sorting"), + sample_covariates=ndata.sample_covariates, ) info("Done!") From 34b440439388b70c8da428adc745b27094297e15 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Sat, 25 Nov 2023 20:23:11 -0500 Subject: [PATCH 06/13] debug condition rep --- notebooks/sample_quality_report.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notebooks/sample_quality_report.ipynb b/notebooks/sample_quality_report.ipynb index bf632b3..9b8f03c 100644 --- a/notebooks/sample_quality_report.ipynb +++ b/notebooks/sample_quality_report.ipynb @@ -387,8 +387,8 @@ " replicate_label\n", " ].values.tolist()\n", " samples_include = ~bdata_filtered.samples[\"_rc\"].isin(rep_has_too_small_sample)\n", - "bdata_filtered = bdata_filtered[:, samples_include]\n", - "bdata_filtered.samples.pop(\"_rc\")" + " bdata_filtered.samples.pop(\"_rc\")\n", + "bdata_filtered = bdata_filtered[:, samples_include]" ] }, { From e1caefce745027dcb0365e4a034513a6f6138526 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Tue, 28 Nov 2023 10:42:58 -0500 Subject: [PATCH 07/13] clean up binary --- bean/model/run.py | 454 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 454 insertions(+) create mode 100644 bean/model/run.py diff --git a/bean/model/run.py b/bean/model/run.py new file mode 100644 index 0000000..a1619f4 --- /dev/null +++ b/bean/model/run.py @@ -0,0 +1,454 @@ +import os +import sys +import argparse +from tqdm import tqdm +import pickle as pkl +import pandas as pd +import logging +from functools import partial +import pyro +import bean.model.model as sorting_model +import bean.model.survival_model as survival_model + +logging.basicConfig( + level=logging.INFO, + format="%(levelname)-5s @ %(asctime)s:\n\t %(message)s \n", + datefmt="%a, %d %b %Y %H:%M:%S", + stream=sys.stderr, + filemode="w", +) +error = logging.critical +warn = logging.warning +debug = logging.debug +info = logging.info +pyro.set_rng_seed(101) + + +def none_or_str(value): + if value == "None": + return None + return value + + +def parse_args(): + print( + r""" + _ _ + / \ '\ + | \ \ _ _ _ _ _ _ + \ \ | | '_| || | ' \ + `.__|/ |_| \_,_|_||_| + """ + ) + print("bean-run: Run model to identify targeted variants and their impact.") + parser = argparse.ArgumentParser(description="Run model on data.") + parser.add_argument( + "selection", + type=str, + choices=["sorting", "survival"], + help="Screen selection type whether cells are sorted based on continuous phenotype ('sorting') or proliferated based on their viability ('survival').", + ) + parser.add_argument( + "library_design", + type=str, + choices=["variant", "tiling"], + help="Library design type whether to run variant or tiling screen model.\nVariant library design assumes gRNA has specific target variant and bystander edits are ignored. Tiling library design considers all alleles generated by gRNA in reporter.", + ) + parser.add_argument("bdata_path", type=str, help="Path of an ReporterScreen object") + parser.add_argument( + "--rep-pi", + "-r", + action="store_true", + default=False, + help="Fit replicate specific scaling factor. 
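The two new positional arguments make the screen type and library design explicit at parse time; argparse's `choices` rejects anything outside the supported modes before any data is touched. A tiny sketch of that mechanism:

```python
import argparse

parser = argparse.ArgumentParser(description="Run model on data.")
parser.add_argument("selection", choices=["sorting", "survival"])
parser.add_argument("library_design", choices=["variant", "tiling"])

args = parser.parse_args(["sorting", "variant"])
print(args.selection, args.library_design)
# parser.parse_args(["knockout", "variant"]) would exit with a usage error.
```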
Recommended to set as True if you expect variable editing activity across biological replicates.", + ) + parser.add_argument( + "--uniform-edit", + "-p", + action="store_true", + default=False, + help="Assume uniform editing rate for all guides.", + ) + parser.add_argument( + "--scale-by-acc", + action="store_true", + default=False, + help="Scale guide editing efficiency by the target loci accessibility", + ) + parser.add_argument( + "--acc-bw-path", + type=str, + default=None, + help="Accessibility .bigWig file to be used to assign accessibility of guides.", + ) + parser.add_argument( + "--acc-col", + type=str, + default=None, + help="Column name in bdata.guides that specify raw ATAC-seq signal.", + ) + parser.add_argument( + "--const-pi", + default=False, + action="store_true", + help="Use constant pi provided in --guide-activity-col (instead of fitting from reporter data)", + ) + parser.add_argument( + "--shrink-alpha", + default=False, + action="store_true", + help="Instead of using the trend-fitted alpha values, use estimated alpha values for each gRNA that are shrunk towards the fitted trend.", + ) + parser.add_argument( + "--condition-col", + default="bin", + type=str, + help="Column key in `bdata.samples` that describes experimental condition.", + ) + parser.add_argument( + "--time-col", + default="time", + type=str, + help="Column key in `bdata.samples` that describes time elapsed.", + ) + parser.add_argument( + "--control-condition-label", + default="bulk", + type=str, + help="Value in `bdata.samples[condition_col]` that indicates control experimental condition.", + ) + parser.add_argument( + "--include-control-condition-for-inference", + "-ic", + default=False, + action="store_true", + help="Include control conditions for inference. Currently only supported for survival screens.", + ) + parser.add_argument( + "--replicate-col", + default="rep", + type=str, + help="Column key in `bdata.samples` that describes experimental replicates.", + ) + parser.add_argument( + "--target-col", + default="target", + type=str, + help="Column key in `bdata.guides` that describes the target element of each guide.", + ) + parser.add_argument( + "--guide-activity-col", + "-a", + type=str, + default=None, + help="Column in ReporterScreen.guides DataFrame showing the editing rate estimated via external tools", + ) + parser.add_argument( + "--outdir", + "-o", + default=".", + type=str, + help="Directory to save the run result.", + ) + parser.add_argument( + "--result-suffix", + default="", + type=str, + help="Suffix of the output files", + ) + parser.add_argument( + "--sorting-bin-upper-quantile-col", + "-uq", + help="Column name with upper quantile values of each sorting bin in [Reporter]Screen.samples (or AnnData.var)", + default="upper_quantile", + ) + parser.add_argument( + "--sorting-bin-lower-quantile-col", + "-lq", + help="Column name with lower quantile values of each sorting bin in [Reporter]Screen.samples (or AnnData var)", + default="lower_quantile", + ) + parser.add_argument( + "--alpha-if-overdispersion-fitting-fails", + "-af", + default=None, + type=str, + help="Comma-separated regression coefficient (b0, b1) of log(a0) ~ log(q) that will be used if fitting dispersion on the data fails.", + ) + parser.add_argument("--cuda", action="store_true", default=False, help="run on GPU") + parser.add_argument( + "--sample-mask-col", + type=str, + default=None, + help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). 
Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples.",
+    )
+    parser.add_argument(
+        "--fit-negctrl",
+        action="store_true",
+        default=False,
+        help="Fit the shared negative control distribution to normalize the fitted parameters",
+    )
+    parser.add_argument(
+        "--negctrl-col",
+        type=str,
+        default="target_group",
+        help="Column in bdata.guides specifying if a guide is a negative control. If `bdata.guides[negctrl_col].lower() == negctrl_col_value`, it is treated as a negative control guide.",
+    )
+    parser.add_argument(
+        "--negctrl-col-value",
+        type=str,
+        default="negctrl",
+        help="Value in bdata.guides[negctrl_col] that marks a guide as a negative control. If `bdata.guides[negctrl_col].lower() == negctrl_col_value`, it is treated as a negative control guide.",
+    )
+    parser.add_argument(
+        "--repguide-mask",
+        type=none_or_str,
+        default="repguide_mask",
+        help="n_replicate x n_guide mask to mask the outlier guides. screen.uns[repguide_mask] will be used.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default=None,
+        help="Optionally use a GPU by providing a valid device name (e.g. cuda:0)",
+    )
+    parser.add_argument(
+        "--ignore-bcmatch",
+        action="store_true",
+        default=False,
+        help="If provided, ignore barcode-matched counts (.X_bcmatch) when fitting, even if the screen object has them.",
+    )
+    parser.add_argument(
+        "--allele-df-key",
+        type=str,
+        default=None,
+        help="screen.uns[allele_df_key] will be used as the allele counts.",
+    )
+    parser.add_argument(
+        "--splice-site-path",
+        type=str,
+        default=None,
+        help="Path to the file listing splice sites",
+    )
+    parser.add_argument(
+        "--control-guide-tag",
+        type=none_or_str,
+        default="CONTROL",
+        help="If this string is found in the guide name, the guide is treated separately so that its positions are not mixed with others. Used for negative controls.",
+    )
+    parser.add_argument(
+        "--dont-fit-noise",  # TODO: add check args
+        action="store_true",
+    )
+    parser.add_argument(
+        "--dont-adjust-confidence-by-negative-control",
+        action="store_true",
+        help="Do not adjust confidence by negative controls. When adjusting (the default), variant library_design uses negative control variants and tiling library_design uses synonymous edits.",
+    )
+    parser.add_argument(
+        "--n-iter",  # TODO: add check args
+        type=int,
+        default=2000,
+        help="# of SVI steps taken for inference.",
+    )
+    parser.add_argument(
+        "--load-existing",  # TODO: add check args
+        action="store_true",
+        help="Load existing .pkl file if present.",
+    )
+
+    return parser.parse_args()
+
+
+def check_args(args, bdata):
+    args.adjust_confidence_by_negative_control = (
+        not args.dont_adjust_confidence_by_negative_control
+    )
+    if args.scale_by_acc:
+        if args.acc_col is None and args.acc_bw_path is None:
+            raise ValueError(
+                "--scale-by-acc requires either --acc-col or --acc-bw-path. Pass one of the two."
+            )
+        elif args.acc_col is not None and args.acc_bw_path is not None:
+            warn(
+                "Both --acc-col and --acc-bw-path are specified. --acc-bw-path is ignored."
+            )
+            args.acc_bw_path = None
+    if args.outdir is None:
+        args.outdir = os.path.dirname(args.bdata_path)
+    if args.library_design == "variant":
+        pass
+    elif args.library_design == "tiling":
+        if args.allele_df_key is None:
+            raise ValueError(
+                "--allele-df-key not provided for tiling screen. Provide the key so that the allele counts in screen.uns[allele_df_key] can be used."
+            )
+    else:
+        raise ValueError(
+            "Invalid library_design provided. Select either 'variant' or 'tiling'."
+        )  # TODO: change this into discrete modes via argparse
+    if args.fit_negctrl:
+        n_negctrl = (
+            bdata.guides[args.negctrl_col].map(lambda s: s.lower())
+            == args.negctrl_col_value.lower()
+        ).sum()
+        if n_negctrl < 20:
+            raise ValueError(
+                f"Not enough negative control guides in the input data: {n_negctrl}. Check your input arguments."
+            )
+    if args.repguide_mask is not None and args.repguide_mask not in bdata.uns.keys():
+        bdata.uns[args.repguide_mask] = pd.DataFrame(
+            index=bdata.guides.index, columns=bdata.samples[args.replicate_col].unique()
+        ).fillna(1)
+        warn(
+            f"{args.bdata_path} does not have replicate x guide outlier mask. All guides are included in analysis."
+        )
+    if args.sample_mask_col is not None:
+        if args.sample_mask_col not in bdata.samples.columns.tolist():
+            raise ValueError(
+                f"{args.bdata_path} does not have specified sample mask column {args.sample_mask_col} in .samples"
+            )
+    if args.alpha_if_overdispersion_fitting_fails is not None:
+        try:
+            b0, b1 = args.alpha_if_overdispersion_fitting_fails.split(",")
+            args.popt = (float(b0), float(b1))
+        except (TypeError, ValueError) as e:
+            raise ValueError(
+                f"Input --alpha-if-overdispersion-fitting-fails {args.alpha_if_overdispersion_fitting_fails} is malformatted! Provide [float],[float] format."
+            ) from e
+    else:
+        args.popt = None
+    return args, bdata
+
+
+def _get_guide_target_info(bdata, args, cols_include=[]):
+    guide_info = bdata.guides.copy()
+    target_info = (
+        guide_info[
+            [args.target_col]
+            + [
+                col
+                for col in guide_info.columns
+                if (
+                    (
+                        (col.startswith("target_"))
+                        and len(guide_info[[args.target_col, col]].drop_duplicates())
+                        == len(guide_info[args.target_col].unique())
+                    )
+                    or col in cols_include
+                )
+                and col != args.target_col
+            ]
+        ]
+        .drop_duplicates()
+        .set_index(args.target_col, drop=True)
+    )
+    target_info["n_guides"] = guide_info.groupby(args.target_col).size()
+
+    if "edit_rate" in guide_info.columns.tolist():
+        edit_rate_info = (
+            guide_info[[args.target_col, "edit_rate"]]
+            .groupby(args.target_col, sort=False)
+            .agg({"edit_rate": ["mean", "std"]})
+        )
+        edit_rate_info.columns = edit_rate_info.columns.get_level_values(1)
+        edit_rate_info = edit_rate_info.rename(
+            columns={"mean": "edit_rate_mean", "std": "edit_rate_std"}
+        )
+        target_info = target_info.join(edit_rate_info)
+    return target_info
+
+
+def run_inference(
+    model, guide, data, initial_lr=0.01, gamma=0.1, num_steps=2000, autoguide=False
+):
+    pyro.clear_param_store()
+    lrd = gamma ** (1 / num_steps)
+    svi = pyro.infer.SVI(
+        model=model,
+        guide=guide,
+        optim=pyro.optim.ClippedAdam({"lr": initial_lr, "lrd": lrd}),
+        loss=pyro.infer.Trace_ELBO(),
+    )
+    losses = []
+    try:
+        for t in tqdm(range(num_steps)):
+            loss = svi.step(data)
+            if t % 100 == 0:
+                print(f"loss {loss} @ iter {t}")
+            losses.append(loss)
+    except ValueError as exc:
+        error(
+            "Error occurred during fitting. Saving temporary output at tmp_result.pkl."
+        )
+        with open("tmp_result.pkl", "wb") as handle:
+            pkl.dump({"param": pyro.get_param_store()}, handle)
+
+        raise ValueError(
+            f"Fitting halted for command: {' '.join(sys.argv)} with following error: \n {exc}"
+        )
+    return {
+        "loss": losses,
+        "params": pyro.get_param_store(),
+    }
+
+
+def identify_model_guide(args):
+    if args.selection == "sorting":
+        m = sorting_model
+    else:
+        m = survival_model
+    if args.library_design == "tiling":
+        info("Using Mixture Normal model...")
+        return (
+            f"MultiMixtureNormal{'+Acc' if args.scale_by_acc else ''}",
+            partial(
+                m.MultiMixtureNormalModel,
+                scale_by_accessibility=args.scale_by_acc,
+                use_bcmatch=(not args.ignore_bcmatch),
+            ),
+            partial(
+                m.MultiMixtureNormalGuide,
+                scale_by_accessibility=args.scale_by_acc,
+                fit_noise=(not args.dont_fit_noise),
+            ),
+        )
+    if args.uniform_edit:
+        if args.guide_activity_col is not None:
+            raise ValueError(
+                "Can't use the guide activity column while constraining uniform edit."
+            )
+        info("Using Normal model...")
+        return (
+            "Normal",
+            partial(m.NormalModel, use_bcmatch=(not args.ignore_bcmatch)),
+            m.NormalGuide,
+        )
+    elif args.const_pi:
+        if args.guide_activity_col is None:
+            raise ValueError(
+                "--guide-activity-col to be used as constant pi is not provided."
+            )
+        info("Using Mixture Normal model with constant weight ...")
+        return (
+            "MixtureNormalConstPi",
+            partial(m.MixtureNormalConstPiModel, use_bcmatch=(not args.ignore_bcmatch)),
+            m.MixtureNormalGuide,
+        )
+    else:
+        info(
+            f"Using Mixture Normal model {'with accessibility normalization' if args.scale_by_acc else ''}..."
+        )
+        return (
+            f"{'_' if args.dont_fit_noise else ''}MixtureNormal{'+Acc' if args.scale_by_acc else ''}",
+            partial(
+                m.MixtureNormalModel,
+                scale_by_accessibility=args.scale_by_acc,
+                use_bcmatch=(not args.ignore_bcmatch),
+            ),
+            partial(
+                m.MixtureNormalGuide,
+                scale_by_accessibility=args.scale_by_acc,
+                fit_noise=(not args.dont_fit_noise),
+            ),
+        )

From 27e3faf2f186e464546eacdf2176fd63029624d0 Mon Sep 17 00:00:00 2001
From: Jayoung Ryu
Date: Tue, 28 Nov 2023 10:57:54 -0500
Subject: [PATCH 08/13] allow external popt

---
 bean/model/model.py                   | 20 ++++++++++----------
 bean/preprocessing/data_class.py      | 15 ++++++++-------
 bean/preprocessing/get_alpha0.py      | 10 +++++++---
 bin/bean-run                          |  5 ++++-
 notebooks/sample_quality_report.ipynb | 27 +++++++++++++--------------
 tests/test_run.py                     | 12 ++++++------
 6 files changed, 48 insertions(+), 41 deletions(-)

diff --git a/bean/model/model.py b/bean/model/model.py
index 91514f1..96514ce 100644
--- a/bean/model/model.py
+++ b/bean/model/model.py
@@ -45,7 +45,7 @@ def NormalModel(
     sd = sd_alleles
     sd = torch.repeat_interleave(sd, data.target_lengths, dim=0)
     assert sd.shape == (data.n_guides, 1)
-    if data.sample_covariates is not None:
+    if hasattr(data, "sample_covariates"):
         with pyro.plate("cov_place", data.n_sample_covariates):
             mu_cov = pyro.sample("mu_cov", dist.Normal(0, 1))
         assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape
@@ -55,15 +55,15 @@ def NormalModel(
             lq = data.lower_bounds[b]
             assert uq.shape == lq.shape == (data.n_condits,)
             with guide_plate:
-                mu = mu.unsqueeze(0).unsqueeze(0).expand(
-                    (data.n_reps, data.n_condits, -1, -1)
-                ) + (data.rep_by_cov * mu_cov)[:, 0].unsqueeze(-1).unsqueeze(
-                    -1
-                ).unsqueeze(
-                    -1
-                ).expand(
-                    (-1, data.n_condits, data.n_guides, 1)
+                mu = (
+                    mu.unsqueeze(0)
+                    .unsqueeze(0)
+                    .expand((data.n_reps, data.n_condits, -1, -1))
                 )
+                if hasattr(data, "sample_covariates"):
+                    mu = mu + (data.rep_by_cov * mu_cov)[:, 
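`run_inference` above uses pyro's `ClippedAdam` with a per-step learning-rate decay `lrd = gamma ** (1 / num_steps)`, so the rate shrinks by exactly the factor `gamma` over the whole run. A quick arithmetic check with the defaults:

```python
initial_lr, gamma, num_steps = 0.01, 0.1, 2000

lrd = gamma ** (1 / num_steps)            # ~0.998850 per step
final_lr = initial_lr * lrd ** num_steps  # == initial_lr * gamma
print(lrd, final_lr)                      # 0.99885..., 0.001
```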
0].unsqueeze(-1).unsqueeze( + -1 + ).unsqueeze(-1).expand((-1, data.n_condits, data.n_guides, 1)) sd = torch.sqrt( ( sd.unsqueeze(0) @@ -511,7 +511,7 @@ def NormalGuide(data): constraint=constraints.positive, ) pyro.sample("sd_alleles", dist.LogNormal(sd_loc, sd_scale)) - if data.sample_covariates is not None: + if hasattr(data, "sample_covariates"): with pyro.plate("cov_place", data.n_sample_covariates): mu_cov_loc = pyro.param( "mu_cov_loc", torch.zeros((data.n_sample_covariates,)) diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 505dab7..2ba381c 100644 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -2,7 +2,7 @@ import abc import logging from dataclasses import dataclass -from typing import Dict, Tuple, List +from typing import Optional, Dict, Tuple, List from xmlrpc.client import Boolean from copy import deepcopy import torch @@ -46,7 +46,8 @@ def __init__( accessibility_bw_path: str = None, device: str = None, replicate_column: str = "rep", - pi_popt: Tuple[float] = None, + popt: Optional[Tuple[float]] = None, + pi_popt: Optional[Tuple[float]] = None, control_can_be_selected: bool = False, **kwargs, ): @@ -113,10 +114,9 @@ def __init__( self.sample_mask_column = sample_mask_column self.repguide_mask = repguide_mask self.shrink_alpha = shrink_alpha + self.popt = popt - def _post_init( - self, - ): + def _post_init(self): # Assign accessibility info if self.accessibility_col is not None: self.guide_accessibility = torch.as_tensor( @@ -185,6 +185,7 @@ def _post_init( self.size_factor.clone().cpu(), self.sample_mask.cpu(), shrink=self.shrink_alpha, + popt=self.popt, ) fitted_a0 = torch.as_tensor(fitted_a0) a0 = fitted_a0 @@ -926,7 +927,7 @@ def _pre_init( self.screen = _assign_rep_ids_and_sort( self.screen, self.replicate_column, self.condition_column ) - if self.sample_covariates is not None: + if hasattr(self, "sample_covariates"): self.rep_by_cov = torch.as_tensor( ( self.screen.samples[["_rc"] + self.sample_covariates] @@ -1013,7 +1014,7 @@ def _post_init( self.screen = _assign_rep_ids_and_sort( self.screen, self.replicate_column, self.time_column ) - if self.sample_covariates is not None: + if hasattr(self, "sample_covariates"): self.rep_by_cov = self.screen.samples.groupby(self.replicate_column)[ self.sample_covariates ].values diff --git a/bean/preprocessing/get_alpha0.py b/bean/preprocessing/get_alpha0.py index 31f3b03..f73f454 100644 --- a/bean/preprocessing/get_alpha0.py +++ b/bean/preprocessing/get_alpha0.py @@ -1,3 +1,4 @@ +from typing import Optional, Tuple import numpy as np import torch from scipy.optimize import curve_fit @@ -71,8 +72,9 @@ def get_fitted_alpha0( sample_size_factors, sample_mask=None, fit_quantile: float = None, - shrink=False, - shrink_prior_var=1.0, + shrink: bool = False, + shrink_prior_var: float = 1.0, + popt: Optional[Tuple[float, float]] = None, ): """Fits sum of concentration of DirichletMultinomial distribution. @@ -80,6 +82,7 @@ def get_fitted_alpha0( fit: if False, return the raw value fit_quantile: if not None, alpha is fitted conservatively with lowest `fit_quantile` guides. 
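The broadcast `mu` and `sd` above feed a standard-normal probability for each sorting bin: for a phenotype X ~ Normal(mu, sd) and bin quantile bounds (lq, uq), the bin weight is Phi((uq - mu) / sd) - Phi((lq - mu) / sd). A minimal stand-in for that computation (not the package's `get_std_normal_prob` itself):

```python
import torch
from torch.distributions import Normal

def bin_prob(lq, uq, mu, sd):
    # P(lq < X < uq) for X ~ Normal(mu, sd): the mass a guide's
    # phenotype distribution places in one sorting bin.
    std = Normal(0.0, 1.0)
    return std.cdf((uq - mu) / sd) - std.cdf((lq - mu) / sd)

# Middle bin spanning the 0.4-0.6 quantiles of a standard normal:
lq, uq = torch.tensor(-0.2533), torch.tensor(0.2533)
print(bin_prob(lq, uq, torch.tensor(0.0), torch.tensor(1.0)))  # ~0.2
```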
+ popt: Regression coefficient (b0, b1) of log(a0) ~ log(q) that will be used if fitting dispersion on the data fails """ n_reps, n_condits, n_guides = X.shape if sample_mask is None: @@ -98,7 +101,8 @@ def get_fitted_alpha0( x, y = get_valid_vals(n.log(), a0.log()) if len(y) < 10: - popt = (-1.510, 0.7861) + if popt is None: + popt = (-1.510, 0.7861) print( f"Cannot fit log(a0) ~ log(q): data too sparse! Using pre-fitted values [b0, b1]={popt}" ) diff --git a/bin/bean-run b/bin/bean-run index 6bff913..810bbc5 100644 --- a/bin/bean-run +++ b/bin/bean-run @@ -84,6 +84,7 @@ def main(args, bdata): control_guide_tag=args.control_guide_tag, target_col=args.target_col, shrink_alpha=args.shrink_alpha, + popt=args.popt, replicate_col=args.replicate_col, use_bcmatch=(not args.ignore_bcmatch), ) @@ -181,7 +182,9 @@ def main(args, bdata): adjust_confidence_by_negative_control=args.adjust_confidence_by_negative_control, adjust_confidence_negatives=adj_negctrl_idx, sd_is_fitted=(args.selection == "sorting"), - sample_covariates=ndata.sample_covariates, + sample_covariates=ndata.sample_covariates + if hasattr(ndata, "sample_covariates") + else None, ) info("Done!") diff --git a/notebooks/sample_quality_report.ipynb b/notebooks/sample_quality_report.ipynb index 9b8f03c..7c18db9 100644 --- a/notebooks/sample_quality_report.ipynb +++ b/notebooks/sample_quality_report.ipynb @@ -286,17 +286,11 @@ " \"Not a base editing data or target base change not provided. Passing editing-related QC\"\n", " )\n", "elif \"edits\" in bdata.layers.keys():\n", - "\n", " bdata.get_guide_edit_rate(\n", - "\n", " editable_base_start=edit_quantification_start_pos,\n", - "\n", " editable_base_end=edit_quantification_end_pos,\n", - "\n", " unsorted_condition_label=ctrl_cond,\n", - "\n", " )\n", - "\n", " be.qc.plot_guide_edit_rates(bdata)" ] }, @@ -311,14 +305,10 @@ " \"Not a base editing data or target base change not provided. 
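The `popt` fallback above plugs pre-fitted (or user-supplied, via --alpha-if-overdispersion-fitting-fails) coefficients into the log-log relationship between mean count and the DirichletMultinomial concentration. Assuming natural logs, consistent with the tensor `.log()` calls in this function, the prediction works out as:

```python
import numpy as np

b0, b1 = -1.510, 0.7861  # default coefficients of log(a0) ~ log(q)

def fallback_a0(q):
    # Predict concentration a0 from mean count q when the
    # per-dataset dispersion fit is too sparse to trust.
    return np.exp(b0 + b1 * np.log(q))

print(fallback_a0(np.array([10.0, 100.0])))  # ~[1.35, 8.25]
```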
Passing editing-related QC\"\n", " )\n", "elif \"edits\" in bdata.layers.keys():\n", - "\n", " bdata.get_edit_rate(\n", " editable_base_start=edit_quantification_start_pos,\n", - "\n", " editable_base_end=edit_quantification_end_pos,\n", - "\n", " )\n", - "\n", " be.qc.plot_sample_edit_rates(bdata)" ] }, @@ -355,13 +345,22 @@ "source": [ "bdata.samples[\"mask\"] = 1\n", "bdata.samples.loc[\n", - " bdata.samples.median_corr_X.isnull() | (bdata.samples.median_corr_X < corr_X_thres), \"mask\"\n", + " bdata.samples.median_corr_X.isnull() | (bdata.samples.median_corr_X < corr_X_thres),\n", + " \"mask\",\n", "] = 0\n", "if \"median_editing_rate\" in bdata.samples.columns.tolist():\n", " bdata.samples.loc[bdata.samples.median_editing_rate < edit_rate_thres, \"mask\"] = 0\n", - "bdata_filtered = bdata[\n", - " :, bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] > lfc_thres\n", - "]" + "if (\n", + " isinstance(replicate_label, str)\n", + " and len(bdata.samples[replicate_label].unique()) > 1\n", + " or isinstance(replicate_label, list)\n", + " and len(bdata.samples[replicate_label].drop_duplicates()) > 1\n", + "):\n", + " bdata_filtered = bdata[\n", + " :, bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] > lfc_thres\n", + " ]\n", + "else:\n", + " bdata_filtered = bdata" ] }, { diff --git a/tests/test_run.py b/tests/test_run.py index 8d39cc4..baa0a80 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -4,7 +4,7 @@ @pytest.mark.order(13) def test_run_variant_wacc(): - cmd = "bean-run variant tests/data/var_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal_chr6.bw -o tests/test_res/var/ --repguide-mask None" + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal_chr6.bw -o tests/test_res/var/ --repguide-mask None" try: subprocess.check_output( cmd, @@ -17,7 +17,7 @@ def test_run_variant_wacc(): @pytest.mark.order(14) def test_run_variant_noacc(): - cmd = "bean-run variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ " + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ " try: subprocess.check_output( cmd, @@ -30,7 +30,7 @@ def test_run_variant_noacc(): @pytest.mark.order(15) def test_run_variant_wo_negctrl_uniform(): - cmd = "bean-run variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ --uniform-edit " + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ --uniform-edit " try: subprocess.check_output( cmd, @@ -43,7 +43,7 @@ def test_run_variant_wo_negctrl_uniform(): @pytest.mark.order(16) def test_run_tiling_wo_negctrl(): - cmd = "bean-run tiling tests/data/tiling_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal.bw -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None" + cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal.bw -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None" try: subprocess.check_output( cmd, @@ -56,7 +56,7 @@ def test_run_tiling_wo_negctrl(): @pytest.mark.order(17) def test_run_tiling_with_wo_negctrl_noacc(): - cmd = "bean-run tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ 
--allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None"
+    cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None"
     try:
         subprocess.check_output(
             cmd,
@@ -69,7 +69,7 @@ def test_run_tiling_with_wo_negctrl_noacc():
 
 @pytest.mark.order(18)
 def test_run_tiling_with_wo_negctrl_uniform():
-    cmd = "bean-run tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None"
+    cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None"
     try:
         subprocess.check_output(
             cmd,

From 83686ea4057b5aa9ccfa70f477cb61eb0e17fd84 Mon Sep 17 00:00:00 2001
From: Jayoung Ryu
Date: Tue, 28 Nov 2023 11:02:27 -0500
Subject: [PATCH 09/13] debug return

---
 bean/preprocessing/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/bean/preprocessing/utils.py b/bean/preprocessing/utils.py
index 44c6431..8445f13 100644
--- a/bean/preprocessing/utils.py
+++ b/bean/preprocessing/utils.py
@@ -47,6 +47,7 @@ def prepare_bdata(bdata: be.ReporterScreen, args, warn, prefix: str):
             f"Ignoring {n_no_support_targets} targets with 0 gRNA counts across all non-control samples. Ignored targets are written in {prefix}/no_support_targets.csv."
         )
         return bdata
+    return bdata
 
 
 def _get_accessibility_single(

From cb5da17cb071343aee822588cf84e7d468fd8c9e Mon Sep 17 00:00:00 2001
From: Jayoung Ryu
Date: Tue, 28 Nov 2023 11:08:36 -0500
Subject: [PATCH 10/13] remove dependency on survival implementation

---
 bean/model/run.py                            |   3 ++-
 tests/data/tiling_mini_screen_annotated.h5ad | Bin 2131872 -> 2131872 bytes
 tests/data/var_mini_screen_annotated.h5ad    | Bin 2921120 -> 2921120 bytes
 3 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/bean/model/run.py b/bean/model/run.py
index a1619f4..2af840c 100644
--- a/bean/model/run.py
+++ b/bean/model/run.py
@@ -8,7 +8,8 @@ from functools import partial
 import pyro
 import bean.model.model as sorting_model
-import bean.model.survival_model as survival_model
+
+# import bean.model.survival_model as survival_model
 logging.basicConfig(
     level=logging.INFO,
diff --git a/tests/data/tiling_mini_screen_annotated.h5ad b/tests/data/tiling_mini_screen_annotated.h5ad
index 9eae70fb100ae7fe0694e0815de3b88d2cc67d08..7e58ea796e4f17aab1221d45a886d8e73c3d6424 100644
Binary files a/tests/data/tiling_mini_screen_annotated.h5ad and b/tests/data/tiling_mini_screen_annotated.h5ad differ
diff --git a/tests/data/var_mini_screen_annotated.h5ad b/tests/data/var_mini_screen_annotated.h5ad
Binary files a/tests/data/var_mini_screen_annotated.h5ad and b/tests/data/var_mini_screen_annotated.h5ad differ

From a7ea8ace725ad6e33a05de7bbbac613867483a10 Mon Sep 17 00:00:00 2001
From: Jayoung Ryu
Date: Tue, 28 Nov 2023 11:09:13 -0500
Subject: [PATCH 11/13] remove dependency on survival implementation

---
 bean/model/run.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bean/model/run.py b/bean/model/run.py
index 2af840c..68fc672 100644
--- a/bean/model/run.py
+++ b/bean/model/run.py
@@ -397,8 +397,8 @@ def identify_model_guide(args):
     if args.selection == "sorting":
         m = sorting_model
-    else:
-        m = survival_model
+    # else:
+    #     m = survival_model
     if args.library_design == "tiling":
         info("Using Mixture Normal 
model...") return ( From 599376acc2187d10f6d7382539fb10cac2fcc2aa Mon Sep 17 00:00:00 2001 From: Jayoung Kim Ryu Date: Tue, 28 Nov 2023 11:30:34 -0500 Subject: [PATCH 12/13] Include codecov --- .github/workflows/CI.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1075c23..313f853 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -34,3 +34,7 @@ jobs: run: | pip install pytest pytest --sparse-ordering + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} From 227944be12008fa60fa5dceb1266cedf0583db07 Mon Sep 17 00:00:00 2001 From: Jayoung Kim Ryu Date: Tue, 28 Nov 2023 11:55:45 -0500 Subject: [PATCH 13/13] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 313f853..0da743b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -8,7 +8,7 @@ name: CI -on: [push] +on: [push, workflow_dispatch] permissions: contents: read