From ac070b609b2124c15bfb3b03c978c12e9034080c Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Mon, 26 Aug 2024 17:56:45 +0000 Subject: [PATCH] debug sample ordering --- bean/notebooks/sample_quality_report.ipynb | 21 ++++++++----- bean/preprocessing/data_class.py | 34 +++++++++++++++++----- bean/qc/utils.py | 12 ++++---- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/bean/notebooks/sample_quality_report.ipynb b/bean/notebooks/sample_quality_report.ipynb index b8df76b..98e294c 100755 --- a/bean/notebooks/sample_quality_report.ipynb +++ b/bean/notebooks/sample_quality_report.ipynb @@ -366,14 +366,19 @@ "metadata": {}, "outputs": [], "source": [ - "plt.hist(\n", - " 1-(\n", - " bdata[:, bdata.samples.condition == ctrl_cond].layers[\"X_bcmatch\"]\n", - " / bdata[:, bdata.samples.condition == ctrl_cond].X\n", - " ).mean(axis=1)\n", - ")\n", - "plt.xlabel(\"Recombination rate\")\n", - "plt.ylabel(\"Frequency\")" + "if \"target_base_changes\" not in bdata.uns or not base_edit_data:\n", + " print(\n", + " \"Not a base editing data or target base change not provided. Passing editing-related QC\"\n", + " )\n", + "elif \"edits\" in bdata.layers.keys():\n", + " plt.hist(\n", + " 1-np.nanmean(\n", + " bdata[:, bdata.samples.condition == ctrl_cond].layers[\"X_bcmatch\"]\n", + " / bdata[:, bdata.samples.condition == ctrl_cond].X\n", + " , axis=1)\n", + " )\n", + " plt.xlabel(\"Recombination rate\")\n", + " plt.ylabel(\"Frequency\")" ] }, { diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 8f8a52a..2b1d63c 100755 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -105,6 +105,7 @@ def __init__( self.screen_control = screen[ :, screen.samples[condition_column].astype(str).isin(control_condition) ] + self.control_can_be_selected = control_can_be_selected self.n_samples = len(screen.samples) # 8 self.n_guides = len(screen.guides) self.n_reps = len(screen.samples[replicate_column].unique()) @@ -328,6 +329,7 @@ def _post_init( self.screen_control.layers["X_bcmatch"], len(self.control_condition) ) # print(self.control_sample_mask.shape, self.X_bcmatch_control.shape) torch.Size([3, 2]) torch.Size([3, 2, 11035]) + # print("bcm init @ ReporterScreen", self.X_bcmatch_control_masked.shape) self.X_bcmatch_control_masked = ( self.X_bcmatch_control * self.control_sample_mask[:, :, None] ) @@ -972,13 +974,26 @@ def _pre_init( .values.astype(int) ) ) - self.screen_selected = _assign_rep_ids_and_sort( - self.screen_selected, self.replicate_column, self.condition_column - ) - self.screen_control = _assign_rep_ids_and_sort( - self.screen_control, - self.replicate_column, - ) + if not self.control_can_be_selected: + self.screen_selected = screen[ + :, + ~( + self.screen.samples[self.condition_column] + .astype(str) + .isin(self.control_condition) + ), + ] + else: + self.screen_selected = self.screen[ + :, ~self.screen.samples[self.condition_column].isnull() + ] + + self.screen_control = self.screen[ + :, + self.screen.samples[self.condition_column] + .astype(str) + .isin(self.control_condition), + ] @dataclass @@ -1370,10 +1385,13 @@ def __getitem__(self, guide_idx): if hasattr(ndata, "X_bcmatch"): ndata.X_bcmatch = ndata.X_bcmatch[:, :, guide_idx] if hasattr(ndata, "X_bcmatch_masked"): + print("b shape", ndata.X_bcmatch.shape) ndata.X_bcmatch_masked = ndata.X_bcmatch_masked[:, :, guide_idx] if hasattr(ndata, "X_bcmatch_control"): + print("bc shape", ndata.X_bcmatch_control.shape) ndata.X_bcmatch_control = ndata.X_bcmatch_control[:, :, guide_idx] if hasattr(ndata, "X_bcmatch_control_masked"): + print("bcm shape", ndata.X_bcmatch_control_masked.shape) ndata.X_bcmatch_control_masked = ndata.X_bcmatch_control_masked[ :, :, guide_idx ] @@ -1395,7 +1413,7 @@ def set_bcmatch(self, screen): self.screen_control.layers["X_bcmatch"], len(self.control_condition) ) self.X_bcmatch_control_masked = ( - self.X_bcmatch_control * self.control_sample_mask[:, None, None] + self.X_bcmatch_control * self.control_sample_mask[:, :, None] ) self.size_factor_bcmatch = torch.as_tensor( self.screen_selected.samples["size_factor_bcmatch"].to_numpy() diff --git a/bean/qc/utils.py b/bean/qc/utils.py index d59d096..f91bf5b 100755 --- a/bean/qc/utils.py +++ b/bean/qc/utils.py @@ -30,11 +30,13 @@ def check_args(args): raise ValueError( f"Specified --target-pos-col `{args.target_pos_col}` does not exist in ReporterScreen.guides.columns ({bdata.guides.columns}). Please check your input. (--tiling {args.tiling}, ReporterScreen.tiling: {bdata.tiling})" ) - if args.posctrl_col and args.posctrl_col not in bdata.guides.columns: - raise ValueError( - f"Specified --posctrl-col `{args.posctrl_col}` does not exist in ReporterScreen.guides.columns ({bdata.guides.columns}). Please check your input." - ) - bdata.guides[args.posctrl_col] = bdata.guides[args.posctrl_col].astype(str) + if args.posctrl_col != "": + if args.posctrl_col and args.posctrl_col not in bdata.guides.columns: + raise ValueError( + f"Specified --posctrl-col `{args.posctrl_col}` does not exist in ReporterScreen.guides.columns ({bdata.guides.columns}). Please check your input." + ) + else: + bdata.guides[args.posctrl_col] = bdata.guides[args.posctrl_col].astype(str) if ( args.posctrl_col and args.posctrl_val not in bdata.guides[args.posctrl_col].tolist()