From e5b3a90628931974494a4460649e2f12fb266618 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Fri, 23 Aug 2024 22:22:18 +0000 Subject: [PATCH] allow user-provided control-conditions --- bean/annotate/utils.py | 10 ++ bean/cli/count.py | 2 + bean/cli/count_samples.py | 7 +- bean/cli/run.py | 2 +- bean/model/parser.py | 5 +- bean/model/run.py | 12 +- bean/model/survival_model.py | 74 +++++++----- bean/notebooks/sample_quality_report.ipynb | 31 +++++ bean/preprocessing/data_class.py | 129 ++++++++++++--------- bean/qc/sample_qc.py | 14 ++- bean/qc/utils.py | 3 +- 11 files changed, 199 insertions(+), 90 deletions(-) diff --git a/bean/annotate/utils.py b/bean/annotate/utils.py index f7d5944..e912fb9 100755 --- a/bean/annotate/utils.py +++ b/bean/annotate/utils.py @@ -450,6 +450,16 @@ def check_args(args): raise ValueError( "Invalid arguments: You should specify exactly one of --translate-fasta, --translate-fastas-csv, --translate-gene, translate-genes-list to translate alleles." ) + if ( + args.translate_fasta is not None + or args.translate_fastas_csv is not None + or args.translate_gene is not None + or args.translate_genes_list is not None + ) and not args.translate: + warn( + "fastq or gene files for translation provided without `--translate` flag. Setting `--translate` flag to True." + ) + args.translate = True if args.translate_genes_list is not None: args.translate_genes_list = ( pd.read_csv(args.translate_genes_list, header=None).values[:, 0].tolist() diff --git a/bean/cli/count.py b/bean/cli/count.py index 53aed0d..7579977 100755 --- a/bean/cli/count.py +++ b/bean/cli/count.py @@ -72,6 +72,8 @@ def main(args): ] if match_target_pos: counter.screen.get_edit_mat_from_uns(target_base_edits, match_target_pos) + else: + counter.screen.get_edit_mat_from_uns(target_base_edits) counter.screen.write(f"{counter.output_dir}.h5ad") counter.screen.to_Excel(f"{counter.output_dir}.xlsx") info(f"Output written at:\n {counter.output_dir}.h5ad,\n {counter.output_dir}.xlsx") diff --git a/bean/cli/count_samples.py b/bean/cli/count_samples.py index 133eaf8..3554368 100755 --- a/bean/cli/count_samples.py +++ b/bean/cli/count_samples.py @@ -92,12 +92,15 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace): screen = counter.screen if screen.X.max() == 0: warn(f"Nothing counted for {sample_id}. Check your input.") - if counter.count_reporter_edits and match_target_pos: + if counter.count_reporter_edits: screen.uns["allele_counts"] = screen.uns["allele_counts"].loc[ screen.uns["allele_counts"].allele.map(str) != "", : ] screen.get_edit_from_allele("allele_counts", "allele") - screen.get_edit_mat_from_uns(target_base_edits, match_target_pos) + if match_target_pos: + screen.get_edit_mat_from_uns(target_base_edits, match_target_pos) + else: + screen.get_edit_mat_from_uns(target_base_edits) info( f"Done for {sample_id}. \n\ Output written at {counter.output_dir}.h5ad" diff --git a/bean/cli/run.py b/bean/cli/run.py index dd17a56..5c51c54 100755 --- a/bean/cli/run.py +++ b/bean/cli/run.py @@ -85,8 +85,8 @@ def main(args, return_data=False): file_logger = logging.FileHandler(f"{prefix}/bean_run.log") file_logger.setLevel(logging.INFO) logging.getLogger().addHandler(file_logger) + info(f"Running: {' '.join(sys.argv[:])}") if args.cuda: - os.environ["CUDA_VISIBLE_DEVICES"] = "1" torch.set_default_tensor_type(torch.cuda.FloatTensor) else: torch.set_default_tensor_type(torch.FloatTensor) diff --git a/bean/model/parser.py b/bean/model/parser.py index 0b41961..144ecd2 100755 --- a/bean/model/parser.py +++ b/bean/model/parser.py @@ -121,9 +121,8 @@ def parse_args(parser=None): "--control-condition", default="bulk", type=str, - help="Value in `bdata.samples[condition_col]` that indicates control experimental condition whose editing patterns will be used. Select this as the condition with the least selection- For the sorting screen, use presort (bulk). For the survival screens, use the closest one with T=0.", + help="Comma-separated list of condition values in `bdata.samples[condition_col]` that indicates control experimental condition whose editing patterns will be used.", ) - input_parser.add_argument( "--plasmid-condition", default="bulk", @@ -167,7 +166,7 @@ def parse_args(parser=None): "--sample-mask-col", type=str, default="mask", - help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples.", + help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples. If you don't want to mask samples out, provide `--sample-mask-col=''`.", ) input_parser.add_argument( diff --git a/bean/model/run.py b/bean/model/run.py index ca52cc2..412c183 100755 --- a/bean/model/run.py +++ b/bean/model/run.py @@ -130,6 +130,8 @@ def check_args(args, bdata): warn( f"{args.bdata_path} does not have replicate x guide outlier mask. All guides are included in analysis." ) + if args.sample_mask_col == "": + args.sample_mask_col = None if args.sample_mask_col is not None: if args.sample_mask_col not in bdata.samples.columns.tolist(): raise ValueError( @@ -139,10 +141,16 @@ def check_args(args, bdata): raise ValueError( f"Condition column `{args.condition_col}` set by `--condition-col` not in ReporterScreen.samples.columns:{bdata.samples.columns}. Check your input." ) - if args.control_condition not in bdata.samples[args.condition_col].tolist(): + if args.selection == "survival" and args.condition_col == args.time_col: raise ValueError( - f"No sample has control label `{args.control_condition}` (set by `--control-condition`) in ReporterScreen.samples[{args.condition_col}]: {bdata.samples[args.condition_col]}. Check your input. For the selection of this argument, see more in `--condition-col` under `bean run --help`." + f"Invalid to have the same `--condition-col` ({args.condition_col}) and `--time-col` ({args.time_col})." ) + control_condits = args.control_condition.split(",") + for control_condit in control_condits: + if control_condit not in bdata.samples[args.condition_col].astype(str).tolist(): + raise ValueError( + f"No sample has control label `{args.control_condition}` (set by `--control-condition`) in ReporterScreen.samples[{args.condition_col}]: {bdata.samples[args.condition_col]}. Check your input. For the selection of this argument, see more in `--condition-col` under `bean run --help`." + ) if args.replicate_col not in bdata.samples.columns: raise ValueError( f"Condition column set by `--replicate-col` {args.replicate_col} not in ReporterScreen.samples.columns:{bdata.samples.columns}. Check your input." diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 4dc1434..f30388e 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -219,6 +219,7 @@ def MixtureNormalModel( data: VariantSurvivalReporterScreenData, alpha_prior: float = 1, use_bcmatch: bool = True, + use_all_timepoints_for_pi: bool = True, sd_scale: float = 0.01, scale_by_accessibility: bool = False, fit_noise: bool = False, @@ -231,6 +232,7 @@ def MixtureNormalModel( data: Input data of type VariantSortingReporterScreenData. alpha_prior: Prior parameter for controlling the concentration of the Dirichlet process. Defaults to 1. use_bcmatch: Flag indicating whether to use barcode-matched counts. Defaults to True. + use_all_timepoints_for_pi: Use all available timepoints instead of the `--control-condition` timepoint. sd_scale: Scale for the prior standard deviation. Defaults to 0.01. scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility. fit_noise: Valid only when scale_by_accessibility is True. If True, parametrically fit noise of endo ~ reporter + noise. @@ -295,28 +297,34 @@ def MixtureNormalModel( pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1) ), ) - with time_plate: + assert pi.shape == ( + data.n_reps, + 1, + data.n_guides, + 2, + ), pi.shape + with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2): with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): - time_pi = data.timepoints - # Accounting for sample specific overall edit rate across all guides. - # P(allele | guide, bin=bulk) - assert pi.shape == ( - data.n_reps, - 1, - data.n_guides, - 2, - ), pi.shape + # if use_all_timepoints_for_pi: + # time_pi = data.timepoints + # expanded_allele_p = pi * r.expand( + # data.n_reps, len(data.timepoints), -1, -1 + # ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand( + # data.n_reps, len(data.timepoints), -1, -1 + # ) + # pyro.sample( + # "allele_count", + # dist.Multinomial(probs=expanded_allele_p, validate_args=False), + # obs=data.allele_counts, + # ) + # else: + time_pi = data.control_timepoint # If pi is sampled in later timepoint, account for the selection. - - expanded_allele_p = pi * r.expand( - data.n_reps, len(data.timepoints), -1, -1 - ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand( - data.n_reps, len(data.timepoints), -1, -1 - ) + expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi pyro.sample( - "allele_count", + "control_allele_count", dist.Multinomial(probs=expanded_allele_p, validate_args=False), - obs=data.allele_counts, + obs=data.allele_counts_control, ) if scale_by_accessibility: # Endogenous target site editing rate may be different @@ -396,6 +404,7 @@ def MultiMixtureNormalModel( data: TilingSurvivalReporterScreenData, alpha_prior=1, use_bcmatch=True, + use_all_timepoints_for_pi: bool = True, sd_scale=0.01, norm_pi=False, scale_by_accessibility=False, @@ -410,6 +419,7 @@ def MultiMixtureNormalModel( data: Input data of type VariantSortingReporterScreenData. alpha_prior: Prior parameter for controlling the concentration of the Dirichlet process. Defaults to 1. use_bcmatch: Flag indicating whether to use barcode-matched counts. Defaults to True. + use_all_timepoints_for_pi: Use all available timepoints instead of the `--control-condition` timepoint. sd_scale: Scale for the prior standard deviation. Defaults to 0.01. scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility. fit_noise: Valid only when scale_by_accessibility is True. If True, parametrically fit noise of endo ~ reporter + noise. @@ -486,25 +496,35 @@ def MultiMixtureNormalModel( with replicate_plate: with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): - time_pi = data.control_timepoint pi = pyro.sample( "pi", dist.Dirichlet( pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1) ), ) - with time_plate: + with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2): with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): + # if use_all_timepoints_for_pi: + # time_pi = data.timepoints + # # If pi is sampled in later timepoint, account for the selection. + # expanded_allele_p = pi * r.expand( + # data.n_reps, 1, -1, -1 + # ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand( + # data.n_reps, len(data.timepoints), -1, -1 + # ) + # pyro.sample( + # "allele_count", + # dist.Multinomial(probs=expanded_allele_p, validate_args=False), + # obs=data.allele_counts, + # ) + # else: + time_pi = data.control_timepoint # If pi is sampled in later timepoint, account for the selection. - expanded_allele_p = pi * r.expand( - data.n_reps, 1, -1, -1 - ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand( - data.n_reps, len(data.timepoints), -1, -1 - ) + expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi pyro.sample( - "allele_count", + "control_allele_count", dist.Multinomial(probs=expanded_allele_p, validate_args=False), - obs=data.allele_counts, + obs=data.allele_counts_control, ) if scale_by_accessibility: # Endogenous target site editing rate may be different diff --git a/bean/notebooks/sample_quality_report.ipynb b/bean/notebooks/sample_quality_report.ipynb index 1a19597..b8df76b 100755 --- a/bean/notebooks/sample_quality_report.ipynb +++ b/bean/notebooks/sample_quality_report.ipynb @@ -93,6 +93,7 @@ "bdata.uns[\"reporter_length\"] = reporter_length\n", "bdata.uns[\"reporter_right_flank_length\"] = reporter_right_flank_length\n", "if posctrl_col:\n", + " bdata.guides[posctrl_col] = bdata.guides[posctrl_col].astype(str)\n", " if posctrl_col not in bdata.guides.columns:\n", " raise ValueError(f\"--posctrl-col argument '{posctrl_col}' is not present in the input ReporterScreen.guides.columns {bdata.guides.columns}. If you do not want to use positive control gRNA annotation for LFC calculation, feed --posctrl-col='' instead.\")\n", " if posctrl_val not in bdata.guides[posctrl_col].tolist():\n", @@ -325,6 +326,13 @@ " )" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Editing rate" + ] + }, { "cell_type": "code", "execution_count": null, @@ -345,6 +353,29 @@ " be.qc.plot_guide_edit_rates(bdata)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### R1-R2 recombination" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.hist(\n", + " 1-(\n", + " bdata[:, bdata.samples.condition == ctrl_cond].layers[\"X_bcmatch\"]\n", + " / bdata[:, bdata.samples.condition == ctrl_cond].X\n", + " ).mean(axis=1)\n", + ")\n", + "plt.xlabel(\"Recombination rate\")\n", + "plt.ylabel(\"Frequency\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 85501b6..7a0711a 100755 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -90,9 +90,11 @@ def __init__( ) replicate_column = "_rc" self.screen = screen + control_condition = control_condition.split(",") if not control_can_be_selected: self.screen_selected = screen[ - :, screen.samples[condition_column] != control_condition + :, + ~(screen.samples[condition_column].astype(int).isin(control_condition)), ] else: self.screen_selected = screen[:, ~screen.samples[condition_column].isnull()] @@ -101,7 +103,7 @@ def __init__( self.screen_selected.var[condition_column].unique() ) # excluding bulk self.screen_control = screen[ - :, screen.samples[condition_column] == control_condition + :, screen.samples[condition_column].astype(str).isin(control_condition) ] self.n_samples = len(screen.samples) # 8 self.n_guides = len(screen.guides) @@ -134,18 +136,22 @@ def _post_init(self): self.sample_mask = torch.as_tensor( self.screen_selected.samples[self.sample_mask_column].to_numpy() ).reshape(self.n_reps, self.n_condits) - self.bulk_sample_mask = torch.as_tensor( + self.control_sample_mask = torch.as_tensor( self.screen_control.samples[self.sample_mask_column].to_numpy() - ) + ).reshape(self.n_reps, len(self.control_condition)) else: self.sample_mask = torch.ones((self.n_reps, self.n_condits), dtype=Boolean) - self.bulk_sample_mask = torch.ones(self.n_reps, dtype=Boolean) + self.control_sample_mask = torch.ones( + (self.n_reps, len(self.control_condition)), dtype=Boolean + ) self.X = self.transform_data( self.screen_selected.X ) # (n_reps, n_bins, n_guides) self.X_masked = self.X * self.sample_mask[:, :, None] - self.X_control = self.transform_data(self.screen_control.X, 1) - self.X_control_masked = self.X_control * self.bulk_sample_mask[:, None, None] + self.X_control = self.transform_data( + self.screen_control.X, len(self.control_condition) + ) + self.X_control_masked = self.X_control * self.control_sample_mask[:, :, None] if self.repguide_mask_key is None: self.repguide_mask = ~(self.X == 0).any(axis=1) else: @@ -182,7 +188,7 @@ def _post_init(self): ).reshape(self.n_reps, self.n_condits) self.size_factor_control = torch.as_tensor( self.screen_control.samples["size_factor"].to_numpy() - ).reshape(self.n_reps, 1) + ).reshape(self.n_reps, len(self.control_condition)) # Get a0 fitted_a0, self.popt = get_fitted_alpha0( self.X.clone().cpu(), @@ -319,25 +325,31 @@ def _post_init( self.X_bcmatch = self.transform_data(self.screen_selected.layers["X_bcmatch"]) self.X_bcmatch_masked = self.X_bcmatch * self.sample_mask[:, :, None] self.X_bcmatch_control = self.transform_data( - self.screen_control.layers["X_bcmatch"], 1 + self.screen_control.layers["X_bcmatch"], len(self.control_condition) ) + # print(self.control_sample_mask.shape, self.X_bcmatch_control.shape) torch.Size([3, 2]) torch.Size([3, 2, 11035]) self.X_bcmatch_control_masked = ( - self.X_bcmatch_control * self.bulk_sample_mask[:, None, None] + self.X_bcmatch_control * self.control_sample_mask[:, :, None] ) self.size_factor_bcmatch = torch.as_tensor( self.screen_selected.samples["size_factor_bcmatch"].to_numpy() ).reshape(self.n_reps, self.n_condits) self.size_factor_bcmatch_control = torch.as_tensor( self.screen_control.samples["size_factor_bcmatch"].to_numpy() - ).reshape(self.n_reps, 1) + ).reshape(self.n_reps, len(self.control_condition)) if hasattr(self, "timepoints") and not hasattr(self, "allele_counts"): + assert hasattr(self, "time_column") control_allele_counts = [] for timepoint in self.timepoints: - screen_t = self.screen[:, self.screen.samples.time == timepoint.item()] + screen_t = self.screen[ + :, self.screen.samples[self.time_column] == timepoint.item() + ] edited_control = self.transform_data(screen_t.layers["edits"], n_bins=1) nonedited_control = ( - self.transform_data(screen_t.layers["X_bcmatch"], 1) + self.transform_data( + screen_t.layers["X_bcmatch"], len(self.control_condition) + ) - edited_control ) nonedited_control[nonedited_control < 0] = 0 @@ -347,7 +359,7 @@ def _post_init( self.allele_counts = torch.cat(control_allele_counts, axis=1) edited_control = self.transform_data( - self.screen_control.layers["edits"], n_bins=1 + self.screen_control.layers["edits"], n_bins=len(self.control_condition) ) nonedited_control = self.X_bcmatch_control - edited_control nonedited_control[nonedited_control < 0] = 0 @@ -625,17 +637,17 @@ def _set_uid_to_row(row): self.n_guides, self.n_max_alleles, ) - self.allele_counts_control = self.transform_allele_control( + self.allele_counts_control = self.transform_allele( self.screen_control, reindexed_df ) self.allele_mask = self.get_allele_mask(self.screen_control, guide_to_allele) assert self.allele_mask.shape == (self.n_guides, self.n_max_alleles) assert self.allele_counts_control.shape == ( self.n_reps, - 1, + len(self.control_condition), self.n_guides, self.n_max_alleles, - ) + ), self.allele_counts_control.shape def get_allele_to_edit_tensor( self, @@ -717,23 +729,33 @@ def reindex_allele_df(self, alleles_df, allele_col): reindexed_allele_df = reindexed_df.drop([allele_col, "index"], axis=1) return (guide_allele_id_to_allele, reindexed_allele_df) - def transform_allele(self, adata, reindexed_df): + def transform_allele(self, adata, reindexed_df, sort_condition_by=None): """ Transform reindexed allele dataframe reindexed_df of (guide, allele_id_for_guide) -> (per sample count) to (n_reps, n_bins, n_guides, n_alleles) tensor. """ + if sort_condition_by is None: + if hasattr(self, "time_column"): + sort_condition_by = self.time_column + else: + sort_condition_by = self.condition_column allele_tensor = torch.empty( - (self.n_reps, self.n_condits, self.n_guides, self.n_max_alleles), + ( + self.n_reps, + len(adata.samples[f"{sort_condition_by}_id"].unique()), + self.n_guides, + self.n_max_alleles, + ), ) if self.device is not None: allele_tensor = allele_tensor.cuda() for i in range(self.n_reps): - for j in range(self.n_condits): + for j, cond in enumerate(adata.samples[f"{sort_condition_by}_id"].unique()): condit_idx = np.where( (adata.samples.replicate_id == i) - & (adata.samples[f"{self.condition_column}_id"] == j) + & (adata.samples[f"{sort_condition_by}_id"] == cond) )[0] - assert len(condit_idx) == 1, print(i, j, condit_idx) + assert len(condit_idx) == 1, (i, j, condit_idx) condit_idx = condit_idx.item() condit_name = adata.samples.index[condit_idx] condit_allele_df = ( @@ -747,10 +769,6 @@ def transform_allele(self, adata, reindexed_df): condit_bcmatch_counts = adata.layers["X_bcmatch"][:, condit_idx].astype( int ) - # if not (condit_bcmatch_counts >= condit_allele_df.sum(axis=1)).all(): - # print( - # f"Allele counts are larger than total bcmatch counts in rep {i}, {j} by {(condit_bcmatch_counts - condit_allele_df.sum(axis=1)).min()}." - # ) condit_allele_df[0] = condit_bcmatch_counts - condit_allele_df.loc[ :, condit_allele_df.columns != 0 ].sum(axis=1) @@ -765,11 +783,6 @@ def transform_allele(self, adata, reindexed_df): ) allele_tensor[i, j, :, :] = torch.as_tensor(condit_allele_df.to_numpy()) - try: - assert (allele_tensor >= 0).all(), allele_tensor[allele_tensor < 0] - except AssertionError: - print("Allele tensor doesn't match condit_allele_df") - return (allele_tensor, reindexed_df) return allele_tensor def transform_allele_control(self, adata, reindexed_df): @@ -779,7 +792,12 @@ def transform_allele_control(self, adata, reindexed_df): """ allele_tensor = torch.empty( - (self.n_reps, 1, self.n_guides, self.n_max_alleles), + ( + self.n_reps, + len(self.control_condition), + self.n_guides, + self.n_max_alleles, + ), ) if self.device is not None: allele_tensor = allele_tensor.cuda() @@ -994,7 +1012,7 @@ def __init__( self._post_init() def _pre_init(self, time_column: str, condition_column: str): - self.condition_column = self.time_column = time_column + self.time_column = time_column try: max_time = self.screen.samples[time_column].astype(float).max() self.screen.samples[time_column] = self.screen.samples[time_column].astype( @@ -1032,11 +1050,9 @@ def _pre_init(self, time_column: str, condition_column: str): # def _post_init( # self, # ): - self.timepoints = torch.as_tensor( - self.screen_selected.samples[self.time_column].unique() - ) + control_timepoint = self.screen_control.samples[self.time_column].unique() - if len(control_timepoint) != 1: + if len(control_timepoint) != len(self.control_condition): info(self.screen_control) info(self.screen_control.samples) info(control_timepoint) @@ -1044,24 +1060,30 @@ def _pre_init(self, time_column: str, condition_column: str): "All samples with --control-condition should have the same --time-col column in ReporterScreen.samples[time_col]. Check your input ReporterScreen object." ) else: - self.control_timepoint = control_timepoint[0] + self.control_timepoint = torch.tensor(control_timepoint) self.n_timepoints = self.n_condits - timepoints = self.screen_selected.samples.sort_values(self.time_column)[ + timepoints = self.screen.samples.sort_values(self.time_column)[ self.time_column ].drop_duplicates() if timepoints.isnull().any(): raise ValueError( - f"NaN values in time points provided in input: {self.screen_selected.samples[self.time_column]}" + f"NaN values in time points provided in input: {self.screen.samples[self.time_column]}" ) for j, time in enumerate(timepoints): - self.screen_selected.samples.loc[ - self.screen_selected.samples[self.time_column] == time, + self.screen.samples.loc[ + self.screen.samples[self.time_column] == time, f"{self.time_column}_id", ] = j - self.screen.samples[f"{self.time_column}_id"] = -1 - self.screen.samples.loc[ - self.screen_selected.samples.index, f"{self.time_column}_id" - ] = self.screen_selected.samples[f"{self.time_column}_id"] + self.screen_selected.samples[f"{self.time_column}_id"] = -1 + self.screen_selected.samples[f"{self.time_column}_id"] = ( + self.screen.samples.loc[ + self.screen_selected.samples.index, f"{self.time_column}_id" + ] + ) + self.screen_control.samples[f"{self.time_column}_id"] = -1 + self.screen_control.samples[f"{self.time_column}_id"] = self.screen.samples.loc[ + self.screen_control.samples.index, f"{self.time_column}_id" + ] self.screen = _assign_rep_ids_and_sort( self.screen, self.replicate_column, self.time_column ) @@ -1078,6 +1100,9 @@ def _pre_init(self, time_column: str, condition_column: str): self.screen_control, self.replicate_column, ) + self.timepoints = torch.as_tensor( + self.screen_selected.samples[self.time_column].unique() + ) @dataclass @@ -1171,17 +1196,17 @@ def set_bcmatch(self, screen): self.X_bcmatch = self.transform_data(self.screen_selected.layers["X_bcmatch"]) self.X_bcmatch_masked = self.X_bcmatch * self.sample_mask[:, :, None] self.X_bcmatch_control = self.transform_data( - self.screen_control.layers["X_bcmatch"], 1 + self.screen_control.layers["X_bcmatch"], len(self.control_condition) ) self.X_bcmatch_control_masked = ( - self.X_bcmatch_control * self.bulk_sample_mask[:, None, None] + self.X_bcmatch_control * self.control_sample_mask[:, :, None] ) self.size_factor_bcmatch = torch.as_tensor( self.screen_selected.samples["size_factor_bcmatch"].to_numpy() ).reshape(self.n_reps, self.n_condits) self.size_factor_bcmatch_control = torch.as_tensor( self.screen_control.samples["size_factor_bcmatch"].to_numpy() - ).reshape(self.n_reps, 1) + ).reshape(self.n_reps, len(self.control_condition)) a0_bcmatch = get_pred_alpha0( self.X_bcmatch.clone().cpu(), self.size_factor_bcmatch.clone().cpu(), @@ -1369,17 +1394,17 @@ def set_bcmatch(self, screen): self.X_bcmatch = self.transform_data(self.screen_selected.layers["X_bcmatch"]) self.X_bcmatch_masked = self.X_bcmatch * self.sample_mask[:, :, None] self.X_bcmatch_control = self.transform_data( - self.screen_control.layers["X_bcmatch"], 1 + self.screen_control.layers["X_bcmatch"], len(self.control_condition) ) self.X_bcmatch_control_masked = ( - self.X_bcmatch_control * self.bulk_sample_mask[:, None, None] + self.X_bcmatch_control * self.control_sample_mask[:, None, None] ) self.size_factor_bcmatch = torch.as_tensor( self.screen_selected.samples["size_factor_bcmatch"].to_numpy() ).reshape(self.n_reps, self.n_condits) self.size_factor_bcmatch_control = torch.as_tensor( self.screen_control.samples["size_factor_bcmatch"].to_numpy() - ).reshape(self.n_reps, 1) + ).reshape(self.n_reps, len(self.control_condition)) a0_bcmatch = get_pred_alpha0( self.X_bcmatch.clone().cpu(), self.size_factor_bcmatch.clone().cpu(), diff --git a/bean/qc/sample_qc.py b/bean/qc/sample_qc.py index 8b66f90..de3b2fd 100755 --- a/bean/qc/sample_qc.py +++ b/bean/qc/sample_qc.py @@ -1,4 +1,5 @@ """Calculate sample quality""" + from typing import Literal import numpy as np import seaborn as sns @@ -7,10 +8,19 @@ linestyles = ["solid", "dotted", "dashed", "dashdot"] -def plot_guide_edit_rates(bdata, ax=None, figsize=(5, 3), title="", n_bins=30): +def plot_guide_edit_rates( + bdata, ax=None, figsize=(5, 3), title="", n_bins=30, plot_normed: bool = True +): + """Plot guide edit rates + Args: + plot_normed: Plot normalized edit rates. If False, plot total edit rates in editing window for tiling ReporterScreen with `.tiling == True`. + """ if ax is None: fig, ax = plt.subplots(figsize=figsize) - sns.histplot(bdata.guides.edit_rate, bins=n_bins) + if "edit_rate_norm" in bdata.guides.columns and plot_normed: + sns.histplot(bdata.guides.edit_rate_norm, bins=n_bins) + else: + sns.histplot(bdata.guides.edit_rate, bins=n_bins) ax.set_title(title) ax.set_xlabel("Editing rate") return ax diff --git a/bean/qc/utils.py b/bean/qc/utils.py index 294f1fb..d59d096 100755 --- a/bean/qc/utils.py +++ b/bean/qc/utils.py @@ -34,12 +34,13 @@ def check_args(args): raise ValueError( f"Specified --posctrl-col `{args.posctrl_col}` does not exist in ReporterScreen.guides.columns ({bdata.guides.columns}). Please check your input." ) + bdata.guides[args.posctrl_col] = bdata.guides[args.posctrl_col].astype(str) if ( args.posctrl_col and args.posctrl_val not in bdata.guides[args.posctrl_col].tolist() ): raise ValueError( - f"Specified --control-condition `{args.posctrl_val}` does not exist in ReporterScreen.guides[{args.posctrl_col}] ({bdata.guides[args.posctrl_col]}). Please check your input." + f"Specified --posctrl-val `{args.posctrl_val}` does not exist in ReporterScreen.guides[{args.posctrl_col}] ({bdata.guides[args.posctrl_col]}). Please check your input. To proceed without positive control, please provide --posctrl-col='' argument." ) if args.control_condition not in bdata.samples[args.condition_col].tolist(): raise ValueError(