diff --git a/bean/annotate/translate_allele.py b/bean/annotate/translate_allele.py index 2e81a89..a8fe178 100644 --- a/bean/annotate/translate_allele.py +++ b/bean/annotate/translate_allele.py @@ -595,7 +595,8 @@ def strsplit_edit(edit_str): def annotate_edit( edit_info: pd.DataFrame, - edit_col="edit", + edit_col: str = "edit", + control_tag: str = "CONTROL", splice_sites: Collection[ int ] = None, # TODO: may be needed to extended into multi-chromosome case @@ -604,6 +605,7 @@ def annotate_edit( Args edit_info: pd.DataFrame with at least 1 column of 'edit_col', which has 'Edit' format. + control_tag: String tag identifying non-targeting control guides so their variant signals are not aggregated. splice_sites: Collection of integer splice site positions. If the edit position matches the positions, it will be annotated as 'splicing'. """ @@ -619,8 +621,13 @@ def annotate_edit( edit_info.loc[ edit_info.pos.map(lambda s: not s.startswith("A")), "coding" ] = "noncoding" - edit_info.loc[edit_info.pos.map(lambda s: "CONTROL" in s), "group"] = "negctrl" - edit_info.loc[edit_info.pos.map(lambda s: "CONTROL" in s), "coding"] = "negctrl" + if control_tag is not None: + edit_info.loc[ + edit_info.pos.map(lambda s: control_tag in s), "group" + ] = "negctrl" + edit_info.loc[ + edit_info.pos.map(lambda s: control_tag in s), "coding" + ] = "negctrl" edit_info.loc[ (edit_info.coding == "noncoding") & (edit_info.group != "negctrl"), "int_pos" ] = edit_info.loc[ diff --git a/bean/model/model.py b/bean/model/model.py index 96514ce..779cfc1 100644 --- a/bean/model/model.py +++ b/bean/model/model.py @@ -206,18 +206,23 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True): data.sample_mask, data.a0_bcmatch, ) - with poutine.mask( - mask=torch.logical_and( - data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) - > mask_thres, - data.repguide_mask, - ) - ): - pyro.sample( - "guide_bcmatch_counts", - dist.DirichletMultinomial(a_bcmatch, validate_args=False), - 
obs=data.X_bcmatch_masked.permute(0, 2, 1), - ) + try: + with poutine.mask( + mask=torch.logical_and( + data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) + > mask_thres, + data.repguide_mask, + ) + ): + pyro.sample( + "guide_bcmatch_counts", + dist.DirichletMultinomial(a_bcmatch, validate_args=False), + obs=data.X_bcmatch_masked.permute(0, 2, 1), + ) + except RuntimeError: + print(data.X_bcmatch_masked.shape) + print(data.repguide_mask.shape) + exit(1) return alleles_p_bin diff --git a/bean/model/run.py b/bean/model/run.py index 0cd7808..96f008f 100644 --- a/bean/model/run.py +++ b/bean/model/run.py @@ -293,7 +293,7 @@ def check_args(args, bdata): bdata.guides[args.negctrl_col].map(lambda s: s.lower()) == args.negctrl_col_value.lower() ).sum() - if not n_negctrl >= 20: + if not n_negctrl >= 10: raise ValueError( f"Not enough negative control guide in the input data: {n_negctrl}. Check your input arguments." ) diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 48f4347..a48be5a 100644 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -1136,6 +1136,20 @@ def set_bcmatch(self, screen): ) self.a0_bcmatch = torch.as_tensor(a0_bcmatch) + def __getitem__(self, guide_idx): + ndata = super().__getitem__(guide_idx) + if hasattr(ndata, "X_bcmatch"): + ndata.X_bcmatch = ndata.X_bcmatch[:, :, guide_idx] + if hasattr(ndata, "X_bcmatch_masked"): + ndata.X_bcmatch_masked = ndata.X_bcmatch_masked[:, :, guide_idx] + if hasattr(ndata, "X_bcmatch_control"): + ndata.X_bcmatch_control = ndata.X_bcmatch_control[:, :, guide_idx] + if hasattr(ndata, "X_bcmatch_control_masked"): + ndata.X_bcmatch_control_masked = ndata.X_bcmatch_control_masked[ + :, :, guide_idx + ] + return ndata + @dataclass class VariantSortingReporterScreenData(VariantReporterScreenData, SortingScreenData): diff --git a/bin/bean-run b/bin/bean-run index 1aa08c4..84ef141 100644 --- a/bin/bean-run +++ b/bin/bean-run @@ -117,6 +117,7 @@ def main(args, 
bdata): pd.DataFrame(pd.Series(ndata.edit_index)) .reset_index() .rename(columns={"index": "edit"}), + control_tag=args.control_guide_tag, splice_sites=None if args.splice_site_path is None else splice_site, ) target_info_df["effective_edit_rate"] = _obtain_effective_edit_rate(ndata).cpu() diff --git a/tests/test_run.py b/tests/test_run.py index 00bca37..3be6fea 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -4,7 +4,7 @@ @pytest.mark.order(13) def test_run_variant_wacc(): - cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal_chr6.bw -o tests/test_res/var/ --repguide-mask None" + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal_chr6.bw -o tests/test_res/var/ --repguide-mask None --n-iter 10" try: subprocess.check_output( cmd, @@ -17,7 +17,7 @@ def test_run_variant_wacc(): @pytest.mark.order(14) def test_run_variant_noacc(): - cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ " + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ --n-iter 10" try: subprocess.check_output( cmd, @@ -30,7 +30,7 @@ def test_run_variant_noacc(): @pytest.mark.order(15) def test_run_variant_wo_negctrl_uniform(): - cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ --uniform-edit " + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ --uniform-edit --n-iter 10" try: subprocess.check_output( cmd, @@ -42,8 +42,8 @@ def test_run_variant_wo_negctrl_uniform(): @pytest.mark.order(16) -def test_run_tiling_wo_negctrl(): - cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal.bw -o tests/test_res/tiling/ --allele-df-key 
allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None" +def test_run_variant_wacc_negctrl(): + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal_chr6.bw -o tests/test_res/var/ --repguide-mask None --n-iter 10 --fit-negctrl " try: subprocess.check_output( cmd, @@ -55,8 +55,8 @@ def test_run_tiling_wo_negctrl(): @pytest.mark.order(17) -def test_run_tiling_with_wo_negctrl_noacc(): - cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None" +def test_run_variant_noacc_negctrl(): + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ --fit-negctrl --n-iter 10" try: subprocess.check_output( cmd, @@ -68,8 +68,47 @@ def test_run_tiling_with_wo_negctrl_noacc(): @pytest.mark.order(18) +def test_run_variant_uniform_negctrl(): + cmd = "bean-run sorting variant tests/data/var_mini_screen_annotated.h5ad -o tests/test_res/var/ --uniform-edit --fit-negctrl --n-iter 10" + try: + subprocess.check_output( + cmd, + shell=True, + universal_newlines=True, + ) + except subprocess.CalledProcessError as exc: + raise exc + + +@pytest.mark.order(19) +def test_run_tiling_wo_negctrl(): + cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal.bw -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None --n-iter 10" + try: + subprocess.check_output( + cmd, + shell=True, + universal_newlines=True, + ) + except subprocess.CalledProcessError as exc: + raise exc + + +@pytest.mark.order(20) +def test_run_tiling_with_wo_negctrl_noacc(): + cmd = "bean-run sorting tiling 
tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None --n-iter 10" + try: + subprocess.check_output( + cmd, + shell=True, + universal_newlines=True, + ) + except subprocess.CalledProcessError as exc: + raise exc + + +@pytest.mark.order(21) def test_run_tiling_with_wo_negctrl_uniform(): - cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None" + cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None --n-iter 10" try: subprocess.check_output( cmd, @@ -78,3 +117,45 @@ def test_run_tiling_with_wo_negctrl_uniform(): ) except subprocess.CalledProcessError as exc: raise exc + + +@pytest.mark.order(22) +def test_run_tiling_negctrl(): + cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad --scale-by-acc --acc-bw-path tests/data/accessibility_signal.bw -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --fit-negctrl --negctrl-col strand --negctrl-col-value neg --control-guide-tag neg --repguide-mask None --n-iter 10" + try: + subprocess.check_output( + cmd, + shell=True, + universal_newlines=True, + ) + except subprocess.CalledProcessError as exc: + raise exc + + +@pytest.mark.order(23) +def test_run_tiling_with_negctrl_noacc(): + cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --fit-negctrl --negctrl-col strand --negctrl-col-value neg --control-guide-tag neg --repguide-mask None --n-iter 10" + try: + 
subprocess.check_output( + cmd, + shell=True, + universal_newlines=True, + ) + except subprocess.CalledProcessError as exc: + raise exc + + +@pytest.mark.order(24) +def test_run_tiling_with_negctrl_uniform(): + cmd = "bean-run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --allele-df-key allele_counts_spacer_0_19_A.G_translated_prop0.1_0.3 --fit-negctrl --negctrl-col strand --negctrl-col-value neg --control-guide-tag neg --repguide-mask None --n-iter 10" + try: + subprocess.check_output( + cmd, + shell=True, + universal_newlines=True, + ) + except subprocess.CalledProcessError as exc: + raise exc + + +# Add fit_negctrl examples