From 509ad9ef64b79b52eba0b7dc0d3f5cd8fd3a5c7d Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Mon, 6 May 2024 23:41:48 -0400 Subject: [PATCH] allow different construct for QC & add per-threshold PASS/FAIL output --- bean/cli/qc.py | 6 ++- bean/framework/ReporterScreen.py | 10 +++- bean/notebooks/sample_quality_report.ipynb | 61 +++++++++++++++++----- bean/preprocessing/get_alpha0.py | 7 ++- bean/preprocessing/utils.py | 24 +++++++-- bean/qc/parser.py | 19 +++++-- bean/qc/utils.py | 38 ++++++++++++++ 7 files changed, 139 insertions(+), 26 deletions(-) diff --git a/bean/cli/qc.py b/bean/cli/qc.py index 31ab475..8c0c736 100755 --- a/bean/cli/qc.py +++ b/bean/cli/qc.py @@ -41,8 +41,8 @@ def main(args): posctrl_col=args.posctrl_col, posctrl_val=args.posctrl_val, lfc_thres=args.lfc_thres, - replicate_label=args.replicate_label, - condition_label=args.condition_label, + replicate_label=args.replicate_col, + condition_label=args.condition_col, comp_cond1=args.lfc_cond1, comp_cond2=args.lfc_cond2, ctrl_cond=args.control_condition, @@ -50,6 +50,8 @@ def main(args): recalculate_edits=(not args.dont_recalculate_edits), base_edit_data=args.base_edit_data, remove_bad_replicates=args.remove_bad_replicates, + reporter_length=args.reporter_length, + reporter_right_flank_length=args.reporter_right_flank_length, ), kernel_name="bean_python3", ) diff --git a/bean/framework/ReporterScreen.py b/bean/framework/ReporterScreen.py index 3d4794c..f7e25fd 100755 --- a/bean/framework/ReporterScreen.py +++ b/bean/framework/ReporterScreen.py @@ -356,6 +356,8 @@ def get_edit_mat_from_uns( rel_pos_is_reporter=False, target_pos_col="target_pos", edit_count_key="edit_counts", + reporter_length: int = 32, + reporter_right_flank_length: int = 6, ): """ Get the edit matrix from `.uns[edit_count_key]` to store the result in `.layers["edits"]` @@ -374,6 +376,10 @@ def get_edit_mat_from_uns( target_base_edit = self.target_base_changes if match_target_position is None: match_target_position = not self.tiling + if "reporter_length" in self.uns: + reporter_length = self.uns["reporter_length"] + if "reproter_right_flank_length" in self.uns: + reporter_right_flank_length = self.uns["reporter_right_flank_length"] if edit_count_key not in self.uns or len(self.uns[edit_count_key]) == 0: raise ValueError( "Edit count isn't calculated. " @@ -400,7 +406,9 @@ def get_edit_mat_from_uns( drop=True ) edits["guide_start_pos"] = ( - 32 - 6 - guide_len[edits.guide_idx].reset_index(drop=True) + reporter_length + - reporter_right_flank_length + - guide_len[edits.guide_idx].reset_index(drop=True) ) if not match_target_position: edits["rel_pos"] = edits.edit.map(lambda e: e.rel_pos) diff --git a/bean/notebooks/sample_quality_report.ipynb b/bean/notebooks/sample_quality_report.ipynb index 03153c7..34c2577 100755 --- a/bean/notebooks/sample_quality_report.ipynb +++ b/bean/notebooks/sample_quality_report.ipynb @@ -60,7 +60,9 @@ "recalculate_edits = True\n", "tiling = None\n", "base_edit_data = True\n", - "remove_bad_replicates = False" + "remove_bad_replicates = False\n", + "reporter_length = 32\n", + "reporter_right_flank_length = 6" ] }, { @@ -84,6 +86,8 @@ " tiling = bdata.uns['tiling']\n", "else:\n", " raise ValueError(\"Ambiguous assignment if the screen is a tiling screen. Provide `--tiling=True` or `tiling=False`.\")\n", + "bdata.uns[\"reporter_length\"] = reporter_length\n", + "bdata.uns[\"reporter_right_flank_length\"] = reporter_right_flank_length\n", "if posctrl_col:\n", " if posctrl_col not in bdata.guides.columns:\n", " raise ValueError(f\"--posctrl-col argument '{posctrl_col}' is not present in the input ReporterScreen.guides.columns {bdata.guides.columns}. If you do not want to use positive control gRNA annotation for LFC calculation, feed --posctrl-col='' instead.\")\n", @@ -128,6 +132,18 @@ "bdata.samples" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for qc_col in [\"gini_X\", \"median_corr_X\", f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\",\"mean_editing_rate\", \"mask\"]:\n", + " if qc_col in bdata.samples:\n", + " del bdata.samples[qc_col]\n", + "n_cols_samples = len(bdata.samples.columns)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -384,25 +400,28 @@ "metadata": {}, "outputs": [], "source": [ - "bdata.samples[\"mask\"] = 1\n", + "mdata = bdata.samples.copy()\n", "# Data has positive control\n", - "bdata.samples.loc[\n", + "for col in mdata.columns.tolist():\n", + " mdata[col]=1.0\n", + "\n", + "mdata.loc[\n", " bdata.samples.median_corr_X.isnull() | (bdata.samples.median_corr_X < count_correlation_thres),\n", - " \"mask\",\n", - "] = 0\n", + " \"median_corr_X\",\n", + "] = 0.0\n", "if \"mean_editing_rate\" in bdata.samples.columns.tolist():\n", - " bdata.samples.loc[bdata.samples.mean_editing_rate < edit_rate_thres, \"mask\"] = 0\n", + " mdata.loc[bdata.samples.mean_editing_rate < edit_rate_thres, \"mean_editing_rate\"] = 0\n", "\n", - "bdata.samples.loc[\n", + "mdata.loc[\n", " bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] < lfc_thres,\n", - " \"mask\",\n", - "] = 0\n", + " f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\",\n", + "] = 0.0\n", "if posctrl_col:\n", " print(\"filter with posctrl LFC\")\n", - " bdata.samples.loc[\n", + " mdata.loc[\n", " bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"].isnull(),\n", - " \"mask\",\n", - " ] = 0" + " f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\",\n", + " ] = 0.0\n" ] }, { @@ -411,7 +430,19 @@ "metadata": {}, "outputs": [], "source": [ - "bdata.samples.style.background_gradient(cmap=\"coolwarm_r\")" + "from matplotlib import colors\n", + "def b_g(s, cmap='coolwarm_r', low=0, high=1):\n", + " a = mdata.loc[:,s.name].copy()\n", + " if s.name not in mdata.columns.tolist()[n_cols_samples:]:\n", + " a[:] = 1.0\n", + " # rng = a.max() - a.min()\n", + " # norm = colors.Normalize(a.min() - (rng * low),\n", + " # a.max() + (rng * high))\n", + " # normed = norm(a.values)\n", + " c = [colors.rgb2hex(x) for x in plt.cm.get_cmap(cmap)(a.values)]\n", + " return ['background-color: %s' % color for color in c]\n", + "print(\"Failing QC is shown as red:\")\n", + "bdata.samples.style.apply(b_g)" ] }, { @@ -421,6 +452,10 @@ "outputs": [], "source": [ "# leave replicate with more than 1 sorting bin data\n", + "print(mdata)\n", + "print(n_cols_samples)\n", + "\n", + "bdata.samples[\"mask\"] = mdata.iloc[:,n_cols_samples:].astype(int).all(axis=1).astype(int).tolist()\n", "if remove_bad_replicates:\n", " rep_n_samples = bdata.samples.groupby(replicate_label)[\"mask\"].sum()\n", " print(rep_n_samples)\n", diff --git a/bean/preprocessing/get_alpha0.py b/bean/preprocessing/get_alpha0.py index f73f454..fe2be18 100755 --- a/bean/preprocessing/get_alpha0.py +++ b/bean/preprocessing/get_alpha0.py @@ -89,6 +89,7 @@ def get_fitted_alpha0( sample_mask = torch.ones((n_reps, n_condits), device="cpu") elif (sample_mask.sum(axis=0) == 0).any(): raise ValueError("Some bins have no data.") + print(sample_size_factors) w = get_w(X + 1, sample_size_factors, sample_mask=sample_mask) q = get_q(X + 1, sample_size_factors, sample_mask=sample_mask) n = ( @@ -100,12 +101,14 @@ def get_fitted_alpha0( a0 = ((n - 1) / (r - 1 + 1 / (1 - p)) - 1).mean(axis=0) x, y = get_valid_vals(n.log(), a0.log()) - if len(y) < 10: + if len(y) < 5: if popt is None: popt = (-1.510, 0.7861) print( - f"Cannot fit log(a0) ~ log(q): data too sparse! Using pre-fitted values [b0, b1]={popt}" + f"Cannot fit log(a0) ~ log(q): data too sparse ({len(y)} valid values)! Using pre-fitted values [b0, b1]={popt}" ) + print(n) + print(a0) else: popt, pcov = curve_fit(linear, x, y) print("Linear fit of log(a0) ~ log(q): [b0, b1]={}, cov={}".format(popt, pcov)) diff --git a/bean/preprocessing/utils.py b/bean/preprocessing/utils.py index 42fc101..4e3d290 100755 --- a/bean/preprocessing/utils.py +++ b/bean/preprocessing/utils.py @@ -26,17 +26,31 @@ def prepare_bdata(bdata: be.ReporterScreen, args, warn, prefix: str): bdata = bdata.copy() bdata.samples["replicate"] = bdata.samples[args.replicate_col].astype("category") bdata.guides = bdata.guides.loc[:, ~bdata.guides.columns.duplicated()].copy() + + # filter out 0-count gRNAs & samples + if args.selection == "sorting" or args.exclude_control_condition_for_inference: + bdata_test = bdata[ + :, bdata.samples[args.condition_col] != args.control_condition + ] + else: + bdata_test = bdata + if any(bdata_test.X.sum(axis=1) == 0): + warn( + f"Filtering out {sum(bdata_test.X.sum(axis=1) == 0)} gRNAs without any counts over all samples." + ) + bdata = bdata[bdata_test.X.sum(axis=1) > 0, :] + if any(bdata[:, bdata.samples.mask == 1].X.sum(axis=0) == 0): + raise ValueError( + f"Sample {bdata.samples.index[(bdata.samples.mask == 1) & (bdata[:,bdata.samples.mask == 1].X.sum(axis=0) == 0)]} has 0 counts. Make sure you mask that sample." + ) + if args.library_design == "variant": if bdata.guides[args.target_col].isnull().any(): raise ValueError( f"Some target column (bdata.guides[{args.target_col}]) value is null. Check your input file." ) bdata = bdata[bdata.guides[args.target_col].argsort(), :] - if any(bdata.X.sum(axis=1) > 0): - warn( - f"Filtering out {sum(bdata.X.sum(axis=1) > 0)} gRNAs without any counts over all samples." - ) - bdata = bdata[bdata.X.sum(axis=1) > 0, :] + n_no_support_targets, bdata = filter_no_info_target( bdata, condit_col=args.condition_col, diff --git a/bean/qc/parser.py b/bean/qc/parser.py index dd41f66..69c4c8e 100755 --- a/bean/qc/parser.py +++ b/bean/qc/parser.py @@ -84,7 +84,7 @@ def parse_args(parser=None): help="Specify that the guide library is tiling library without 'n guides per target' design", ) input_parser.add_argument( - "--replicate-label", + "--replicate-col", help="Label of column in `bdata.samples` that describes replicate ID.", type=str, default="replicate", @@ -96,7 +96,7 @@ def parse_args(parser=None): default=None, ) input_parser.add_argument( - "--condition-label", + "--condition-col", help="Label of column in `bdata.samples` that describes experimental condition. (sorting bin, time, etc.)", type=str, default="condition", @@ -117,11 +117,13 @@ def parse_args(parser=None): "--edit-start-pos", help="Edit start position to quantify editing rate on, 0-based inclusive.", default=2, + type=int, ) input_parser.add_argument( "--edit-end-pos", help="Edit end position to quantify editing rate on, 0-based exclusive.", default=7, + type=int, ) input_parser.add_argument( @@ -149,5 +151,16 @@ def parse_args(parser=None): type=str, default="bulk", ) - + parser.add_argument( + "--reporter-length", + type=int, + default=32, + help="Length of reporter sequence in the construct.", + ) + parser.add_argument( + "--reporter-right-flank-length", + type=int, + default=6, + help="Length of the right-flanking nucleotides of protospacer in the reporter.", + ) return parser diff --git a/bean/qc/utils.py b/bean/qc/utils.py index b56738a..7ca4b90 100755 --- a/bean/qc/utils.py +++ b/bean/qc/utils.py @@ -3,9 +3,39 @@ import pandas as pd from copy import deepcopy from bean.framework.ReporterScreen import ReporterScreen, concat +import bean as be def check_args(args): + bdata = be.read_h5ad(args.bdata_path) + if args.replicate_col not in bdata.samples.columns: + raise ValueError( + f"Specified --replicate-col `{args.replicate_col}` does not exist in ReporterScreen.samples.columns ({bdata.samples.columns}). Please check your input." + ) + if args.condition_col not in bdata.samples.columns: + raise ValueError( + f"Specified --condition-col `{args.condition_col}` does not exist in ReporterScreen.samples.columns ({bdata.samples.columns}). Please check your input." + ) + if not bdata.tiling and args.target_pos_col not in bdata.guides.columns: + raise ValueError( + f"Specified --target-pos-col `{args.target_pos_col}` does not exist in ReporterScreen.guides.columns ({bdata.guides.columns}). Please check your input." + ) + if args.posctrl_col and args.posctrl_col not in bdata.guides.columns: + raise ValueError( + f"Specified --posctrl-col `{args.posctrl_col}` does not exist in ReporterScreen.guides.columns ({bdata.guides.columns}). Please check your input." + ) + if ( + args.posctrl_col + and args.posctrl_val not in bdata.guides[args.posctrl_col].tolist() + ): + raise ValueError( + f"Specified --control-condition `{args.posctrl_val}` does not exist in ReporterScreen.guides[{args.posctrl_col}] ({bdata.guides[args.posctrl_col]}). Please check your input." + ) + if args.control_condition not in bdata.samples[args.condition_col].tolist(): + raise ValueError( + f"Specified --control-condition `{args.control_condition}` does not exist in ReporterScreen.samples[{args.condition_col}] :\n{bdata.samples[args.condition_col]}.\n Please check your input. \nFeed the condition where the editing rate would be quantified as the --control-condition argument, usually the condition with the least selection. (Closest to T0 for survival, pre-sort or bulk for sorting screens)." + ) + lfc_conds = args.lfc_conds.split(",") if not len(lfc_conds) == 2: raise ValueError( @@ -13,6 +43,14 @@ def check_args(args): ) args.lfc_cond1 = lfc_conds[0] args.lfc_cond2 = lfc_conds[1] + if args.lfc_cond1 not in bdata.samples[args.condition_col].tolist(): + raise ValueError( + f"Specified --lfc-conds `{args.lfc_cond1}` does not exist in ReporterScreen.samples[{args.condition_col}]:\n{bdata.samples[args.condition_col]}. Please check your input." + ) + if args.lfc_cond2 not in bdata.samples[args.condition_col].tolist(): + raise ValueError( + f"Specified --lfc-conds `{args.lfc_cond2}` does not exist in ReporterScreen.samples[{args.condition_col}]:\n{bdata.samples[args.condition_col]}. Please check your input." + ) if args.sample_covariates is not None: if "," in args.sample_covariates: args.sample_covariates = args.sample_covariates.split(",")