Skip to content

Commit

Permalink
allow missing samples in qc
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Sep 20, 2023
1 parent 8838653 commit 3d7b8cd
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 12 deletions.
17 changes: 10 additions & 7 deletions bean/framework/ReporterScreen.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,13 @@ def get_edit_mat_from_uns(
+ "Call .get_edit_from_allele(allele_count_key, allele_key)"
)
edits = self.uns[edit_count_key].copy()
if self.layers["edits"] is not None:
if "edits" in self.layers:
old_edits = self.layers["edits"].copy()
self.layers["edits"] = np.zeros_like(
self.X,
)
else:
old_edits = None
self.layers["edits"] = np.zeros_like(
self.X,
).astype(float)
edits["ref_base"] = edits.edit.map(lambda e: e.ref_base)
edits["alt_base"] = edits.edit.map(lambda e: e.alt_base)
edits = edits.loc[
Expand Down Expand Up @@ -487,18 +487,20 @@ def get_guide_edit_rate(

if prior_weight is None:
prior_weight = 1
n_edits = self.layers[edit_layer][:, bulk_idx].sum(axis=1)
n_counts = self.layers[count_layer][:, bulk_idx].sum(axis=1)
n_edits = self.layers[edit_layer].copy()[:, bulk_idx].sum(axis=1)
n_counts = self.layers[count_layer].copy()[:, bulk_idx].sum(axis=1)
edit_rate = (n_edits + prior_weight / 2) / (
(n_counts * num_targetable_sites) + prior_weight / 2
)
edit_rate[n_counts < bcmatch_thres] = np.nan
if normalize_by_editable_base:
print("normalize by editable counts")
edit_rate[num_targetable_sites == 0] = np.nan
if return_result:
return edit_rate
else:
self.guides["edit_rate"] = edit_rate
print(self.guides.edit_rate)

def get_edit_rate(
self,
Expand Down Expand Up @@ -630,7 +632,8 @@ def filter_allele_counts_by_pos(
)
)
else:
if not 'guide_len' in self.guides.columns.tolist(): self.guides['guide_len'] = self.guides.sequence.map(len)
if not "guide_len" in self.guides.columns.tolist():
self.guides["guide_len"] = self.guides.sequence.map(len)
guide_start_pos = (
32 - 6 - self.guides.loc[allele_count_df.guide, "guide_len"].values
)
Expand Down
3 changes: 3 additions & 0 deletions bean/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def NormalModel(
with replicate_plate2:
with pyro.plate("guide_plate3", data.n_guides, dim=-1):
a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0)
print(a)
print(a.max())
print(a.min())
a_bcmatch = get_alpha(
expected_guide_p,
data.size_factor_bcmatch,
Expand Down
4 changes: 2 additions & 2 deletions bean/qc/sample_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def set_sample_edit_rates(
)
if agg_method == "median":
bdata.samples[f"{agg_method}_editing_rate"] = np.nanmedian(
bdata.layers["edit_rate"], axis=0
bdata.layers["edit_rate"].copy(), axis=0
)
if agg_method == "mean":
bdata.samples[f"{agg_method}_editing_rate"] = np.nanmean(
bdata.layers["edit_rate"], axis=0
bdata.layers["edit_rate"].copy(), axis=0
)
65 changes: 65 additions & 0 deletions bean/qc/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import numpy as np
import pandas as pd
from copy import deepcopy
import argparse
from ..framework.ReporterScreen import ReporterScreen, concat


def parse_args():
Expand All @@ -23,6 +27,12 @@ def parse_args():
help="Path where quality-filtered ReporterScreen object to be written to",
type=str,
)
parser.add_argument(
"-i",
"--ignore-missing-samples",
help="If the flag is not provided, if the ReporterScreen object does not contain all condiitons for each replicate, make fake empty samples. If the flag is provided, don't add dummy samples.",
action="store_true",
)
parser.add_argument(
"-r",
"--out-report-prefix",
Expand Down Expand Up @@ -99,3 +109,58 @@ def parse_args():
if args.out_report_prefix is None:
args.out_report_prefix = f"{args.bdata_path.rsplit('.h5ad', 1)[0]}.qc_report"
return args


def _add_dummy_sample(bdata, rep, cond, condition_label: str, replicate_label: str):
sample_id = f"{rep}_{cond}"
cond_df = deepcopy(bdata.samples)
cond_df[replicate_label] = np.nan
cond_df = cond_df.drop_duplicates()
cond_row = cond_df.loc[cond_df[condition_label] == cond, :]
if not len(cond_row) == 1:
raise ValueError(
f"Non-unique condition specification in ReporterScreen.samples: {cond_row}"
)
cond_row.index = [sample_id]
cond_row.loc[:, replicate_label] = rep
dummy_sample_bdata = ReporterScreen(
X=np.zeros((bdata.n_obs, 1)),
X_bcmatch=np.zeros((bdata.n_obs, 1)),
guides=bdata.guides,
samples=cond_row,
)
for k in bdata.uns.keys():
if isinstance(bdata.uns[k], pd.DataFrame):
dummy_sample_bdata.uns[k] = pd.DataFrame(
columns=bdata.uns[k].columns.tolist()[:2] + [sample_id]
)
else:
dummy_sample_bdata.uns[k] = bdata.uns[k]
bdata = concat([bdata, dummy_sample_bdata])
return bdata


def fill_in_missing_samples(bdata, condition_label: str, replicate_label: str):
"""If not all condition exists for every replicate in bdata, fill in fake sample"""
added_dummy = False
for rep in bdata.samples[replicate_label].unique():
for cond in bdata.samples[condition_label].unique():
if (
len(
np.where(
(bdata.samples[replicate_label] == rep)
& (bdata.samples[condition_label] == cond)
)[0]
)
!= 1
):
bdata = _add_dummy_sample(
bdata, rep, cond, condition_label, replicate_label
)
if not added_dummy:
added_dummy = True
if added_dummy:
bdata = bdata[
:, bdata.samples.sort_values([replicate_label, condition_label]).index
]
return bdata
4 changes: 2 additions & 2 deletions bin/bean-qc
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import os
import papermill as pm
import bean as be
from bean.qc.utils import parse_args

from bean.qc.utils import parse_args, check_args


def main():
args = parse_args()
check_args(args)
os.system(
"python -m ipykernel install --user --name bean_python3 --display-name bean_python3"
)
Expand Down
1 change: 1 addition & 0 deletions bin/bean-run
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def main(args, bdata):
shrink_alpha=args.shrink_alpha,
replicate_col=args.replicate_col,
)
print(ndata.a0)
adj_negctrl_idx = None
if args.mode == "variant":
if "edit_rate" not in bdata.guides.columns:
Expand Down
35 changes: 35 additions & 0 deletions notebooks/sample_quality_report.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"source": [
"import perturb_tools as pt\n",
"import bean as be\n",
"from bean.qc.utils import fill_in_missing_samples\n",
"import matplotlib.pyplot as plt\n",
"plt.style.use('default')"
]
Expand Down Expand Up @@ -64,6 +65,40 @@
"bdata = be.read_h5ad(bdata_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Add dummy samples if not paired"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bdata = fill_in_missing_samples(bdata, condition_label, replicate_label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bdata.samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bdata.X"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import subprocess


@pytest.mark.order(5)
@pytest.mark.order(7)
def test_filter_varscreen():
cmd = "bean-filter tests/data/var_mini_screen_masked.h5ad -o tests/data/var_mini_screen_annotated -s 0 -e 19 -w -b -t -ap 0.1 -sp 0.3"
try:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,29 @@ def test_qc_tiling():
)
except subprocess.CalledProcessError as exc:
raise exc


@pytest.mark.order(5)
def test_dummy_insertion_varscreen():
cmd = "bean-qc tests/data/var_mini_screen_missing.h5ad -o tests/data/var_mini_screen_missing_masked.h5ad -r tests/test_res/qc_report_var_mini_screen_missing --count-correlation-thres 0.6"
try:
subprocess.check_output(
cmd,
shell=True,
universal_newlines=True,
)
except subprocess.CalledProcessError as exc:
raise exc


@pytest.mark.order(6)
def test_dummy_insertion_tilingscreen():
cmd = "bean-qc tests/data/tiling_mini_screen_missing.h5ad -o tests/data/tiling_mini_screen_missing_masked.h5ad -r tests/test_res/qc_report_tiling_mini_screen_missing --count-correlation-thres 0.6"
try:
subprocess.check_output(
cmd,
shell=True,
universal_newlines=True,
)
except subprocess.CalledProcessError as exc:
raise exc

0 comments on commit 3d7b8cd

Please sign in to comment.