Merge pull request #16 from pinellolab/dev
v0.3.0
jykr authored Dec 1, 2023
2 parents aa14578 + d7f16d4 commit 6c64277
Showing 23 changed files with 2,473 additions and 673 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/CI.yml
@@ -34,3 +34,7 @@ jobs:
run: |
pip install pytest
pytest --sparse-ordering
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
23 changes: 16 additions & 7 deletions README.md
Expand Up @@ -22,7 +22,7 @@ This is an analysis toolkit for the pooled CRISPR reporter or sensor data. The r
4. [`bean-filter`](#bean-filter-filtering-and-optionally-translating-alleles): Filter reporter alleles; essential for `tiling` mode that allows for all alleles generated from gRNA.
5. [`bean-run`](#bean-run-quantify-variant-effects): Quantify targeted variants' effect sizes from screen data.

### Data structure
### Screen data is saved as a *ReporterScreen* object in the pipeline.
BEAN stores mapped gRNA and allele counts in a `ReporterScreen` object, which is compatible with [AnnData](https://anndata.readthedocs.io/en/latest/index.html). See the [Data Structure](#data-structure) section for more information.
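
A minimal sketch of inspecting the saved object with plain AnnData (the file name below is hypothetical; layer and `.uns` contents vary by screen):

```python
import anndata as ad

# A ReporterScreen saved as .h5ad is readable as a regular AnnData object.
bdata = ad.read_h5ad("my_sorting_screen.h5ad")
print(bdata)                # guides x samples
print(bdata.layers.keys())  # e.g. barcode-matched count layers such as "X_bcmatch"
print(bdata.uns.keys())     # e.g. allele count tables stored as DataFrames
```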

### Examples
@@ -97,7 +97,7 @@ File should contain following columns with header.
* `R1_filepath`: Path to read 1 `.fastq[.gz]` file
* `R2_filepath`: Path to read 2 `.fastq[.gz]` file
* `sample_id`: ID of sequencing sample
* `rep [Optional]`: Replicate # of this sample
* `rep [Optional]`: Replicate # of this sample (Should NOT contain `.`)
* `bin [Optional]`: Name of the sorting bin
* `upper_quantile [Optional]`: FACS sorting upper quantile
* `lower_quantile [Optional]`: FACS sorting lower quantile
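
A minimal sketch of building such a sample sheet with pandas (all values below are hypothetical) and checking that replicate names contain no `.`, which would break downstream keys:

```python
import pandas as pd

# Hypothetical sample sheet with the columns described above.
samples = pd.DataFrame(
    {
        "R1_filepath": ["rep1_bot_R1.fastq.gz", "rep1_top_R1.fastq.gz"],
        "R2_filepath": ["rep1_bot_R2.fastq.gz", "rep1_top_R2.fastq.gz"],
        "sample_id": ["rep1_bot", "rep1_top"],
        "rep": ["rep1", "rep1"],
        "bin": ["bot", "top"],
        "upper_quantile": [0.3, 1.0],
        "lower_quantile": [0.0, 0.7],
    }
)
# Replicate names are used as keys elsewhere in the pipeline, so they must not contain ".".
assert not samples["rep"].str.contains(".", regex=False).any()
samples.to_csv("sample_list.csv", index=False)
```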
@@ -162,6 +162,11 @@ bean-create-screen gRNA_library.csv sample_list.csv gRNA_counts_table.csv
```bash
bean-profile my_sorting_screen.h5ad -o output_prefix `# Prefix for editing profile report`
```
### Output
Above command produces `prefix_editing_preference.[html,ipynb]` as editing preferences ([see example](notebooks/profile_editing_preference.ipynb)).

<img src="imgs/profile_output.png" alt="Allele translation" width="700" style="background-color:white;"/>

### Parameters
* `-o`, `--output-prefix` (default: `None`): Output prefix of editing pattern report (prefix.html, prefix.ipynb). If not provided, base name of `bdata_path` is used.
* `--replicate-col` (default: `"rep"`): Column name in `bdata.samples` that describes replicate ID.
@@ -170,8 +175,6 @@ bean-profile my_sorting_screen.h5ad -o output_prefix `# Prefix for editing profi
* `--control-condition` (default: `"bulk"`): Control condition at which editing preference is profiled. Pre-filters data where `bdata.samples[condition_col] == control_condition`.
* `-w`, `--window-length` (default: `6`): Window length of editing window of maximal editing efficiency to be identified. This window is used to quantify context specificity within the window.
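
Illustrative only (not the `bean-profile` implementation, and the per-position rates below are made up): the maximal-efficiency window of length `w` can be found as the window with the largest summed editing rate.

```python
import numpy as np

# Hypothetical per-position editing rates across the protospacer.
editing_rate = np.array([0.01, 0.05, 0.20, 0.35, 0.40, 0.30, 0.15, 0.05, 0.02, 0.01])
w = 6  # corresponds to --window-length
# Summed editing rate over every window of length w; pick the window with the largest sum.
window_sums = np.convolve(editing_rate, np.ones(w), mode="valid")
start = int(window_sums.argmax())
print(f"maximal-editing window: positions {start}..{start + w - 1} (0-based)")
```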

### Output
Above command produces `prefix_editing_preference.[html,ipynb]` as editing preferences ([see example](notebooks/profile_editing_preference.ipynb)).

<br/><br/>

@@ -183,7 +186,10 @@ bean-qc \
-r qc_report_my_sorting_screen `# Prefix for QC report`
```

`bean-qc` performs the following quality control checks and masks samples with low quality. Specifically:

<img src="imgs/qc_output.png" alt="Allele translation" width="900" style="background-color:white;"/>

* Plots guide coverage and the uniformity of coverage
* Guide count correlation between samples
* Log fold change correlation when positive controls are provided
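
For illustration only (simulated counts, not the `bean-qc` implementation), the count-correlation check amounts to correlating log-transformed guide counts between samples:

```python
import numpy as np
import pandas as pd

# Hypothetical guide count matrix: rows are guides, columns are samples.
counts = pd.DataFrame(
    np.random.default_rng(0).poisson(100, size=(1000, 4)),
    columns=["rep1_bot", "rep1_top", "rep2_bot", "rep2_top"],
)
# Pairwise Pearson correlation of log-transformed counts between samples;
# low-correlation samples would be candidates for masking.
print(np.log1p(counts).corr().round(2))
```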
@@ -201,6 +207,8 @@ Above command produces
* `--tiling` (default: `None`): If set to `True` or `False`, treats the screen object as a tiling (`True`) or variant-targeting (`False`) screen when calculating the editing rate.
* `--replicate-label` (default: `"rep"`): Label of column in `bdata.samples` that describes replicate ID.
* `--condition-label` (default: `"bin"`): Label of column in `bdata.samples` that describes experimental condition (sorting bin, time, etc.).
* `--sample-covariates` (default: `None`): Comma-separated list of column names in `bdata.samples` that describe non-selective experimental conditions (drug treatment, etc.). The values in `bdata.samples` should NOT contain `.`.
* `--no-editing` (default: `False`): Ignore QC about editing. Can be used for QC of other editing modalities.
* `--target-pos-col` (default: `"target_pos"`): Target position column in `bdata.guides` specifying target edit position in reporter.
* `--rel-pos-is-reporter` (default: `False`): Specifies whether `edit_start_pos` and `edit_end_pos` are relative to reporter position. If `False`, those are relative to spacer position.
* `--edit-start-pos` (default: `2`): Edit start position to quantify editing rate on, 0-based inclusive.
@@ -211,6 +219,7 @@ Above command produces
* `--posctrl-val` (default: `PosCtrl`): Value in `.h5ad.guides[posctrl_col]` that specifies a guide will be used as a positive control in calculating log fold change.
* `--lfc-thres` (default: `0.1`): Positive guides' correlation threshold to filter out.
* `--lfc-conds` (default: `"top,bot"`): Comma-delimited pair of values in `ReporterScreen.samples[condition_label]` between which the LFC will be calculated
* `--ctrl-cond` (default: `"bulk"`): Value in `ReporterScreen.samples[condition_label]` for which the guide-level editing rate is calculated
* `--recalculate-edits` (default: `False`): Even when `ReporterScreen.layers['edit_count']` exists, recalculate the edit counts from `ReporterScreen.uns['allele_count']`.
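
A hedged sketch of the log fold change check between the `--lfc-conds` bins using positive-control guides (all counts and names below are hypothetical; `bean-qc` computes its own version):

```python
import numpy as np
import pandas as pd

# Hypothetical normalized counts of positive-control guides in the two sorting bins.
top = pd.DataFrame({"rep1": [120, 80, 55], "rep2": [130, 75, 60]},
                   index=["ctrl_g1", "ctrl_g2", "ctrl_g3"])
bot = pd.DataFrame({"rep1": [40, 160, 90], "rep2": [45, 150, 85]},
                   index=["ctrl_g1", "ctrl_g2", "ctrl_g3"])
# Guide-level LFC between the --lfc-conds bins ("top,bot"), per replicate.
lfc = np.log2((top + 0.5) / (bot + 0.5))
# Replicates whose positive-control LFCs correlate poorly would be flagged (cf. --lfc-thres).
print(lfc.corr())
```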

<br/><br/>
@@ -278,7 +287,7 @@ bean-filter my_sorting_screen.h5ad \

## `bean-run`: Quantify variant effects
BEAN uses a Bayesian network to incorporate the gRNA editing outcome and provide a posterior estimate of variant phenotype. The Bayesian network reflects the data generation process. Briefly,
1. Cellular phenotype is modeled as the Gaussian mixture distribution of wild-type phenotype and variant phenotype.
1. Cellular phenotype (either the phenotype cells are sorted on in a sorting screen, or log(proliferation rate) in a survival screen) is modeled as the Gaussian mixture distribution of wild-type phenotype and variant phenotype.
2. The weights of the mixture components are inferred from the reporter editing outcome and the chromatin accessibility of the loci.
3. Cells with each gRNA, formulated as the mixture distribution, are sorted by the phenotypic quantile to produce the gRNA counts.
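
As a worked sketch of step 3 (illustrative numbers only, not the `bean-run` code): for one gRNA, the expected fraction of cells landing in a sorting bin is a mixture-weighted difference of Normal CDFs evaluated at the bin's quantile bounds (cf. `get_std_normal_prob` in `bean/model/model.py`).

```python
from scipy.stats import norm

# Hypothetical values for one gRNA: a fraction `pi` of its cells carry the edit.
pi, mu_variant, mu_wt, sd = 0.3, 0.8, 0.0, 1.0
# A sorting bin defined by phenotype quantiles of the overall population, e.g. the top 30%.
lower, upper = norm.ppf(0.7), float("inf")
p_variant = norm.cdf(upper, mu_variant, sd) - norm.cdf(lower, mu_variant, sd)
p_wt = norm.cdf(upper, mu_wt, sd) - norm.cdf(lower, mu_wt, sd)
# Expected fraction of this guide's cells sorted into the bin.
p_bin = pi * p_variant + (1 - pi) * p_wt
print(round(p_bin, 3))
```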

@@ -288,7 +297,7 @@ For the full detail, see the method section of the [BEAN manuscript](https://www

<br></br>
```bash
bean-run variant[tiling] my_sorting_screen_filtered.h5ad \
bean-run sorting[survival] variant[tiling] my_sorting_screen_filtered.h5ad \
[--uniform-edit, --scale-by-acc [--acc-bw-path accessibility_signal.bw, --acc-col accessibility]] \
-o output_prefix/ \
--fit-negctrl
20 changes: 17 additions & 3 deletions bean/framework/ReporterScreen.py
@@ -99,6 +99,8 @@ def __init__(
self.layers["X_bcmatch"] = X_bcmatch
for k, df in self.uns.items():
if not isinstance(df, pd.DataFrame):
if k == "sample_covariates" and not isinstance(df, list):
self.uns[k] = df.tolist()
continue
if "guide" in df.columns and len(df) > 0:
if (
@@ -323,8 +325,20 @@ def __getitem__(self, index):
new_uns = deepcopy(self.uns)
for k, df in adata.uns.items():
if k.startswith("repguide_mask"):
new_uns[k] = df.loc[guides_include, adata.var.rep.unique()]
if "sample_covariates" in adata.uns:
adata.var["_rc"] = adata.var[
["rep"] + list(adata.uns["sample_covariates"])
].values.tolist()
adata.var["_rc"] = adata.var["_rc"].map(
lambda slist: ".".join(slist)
)
new_uns[k] = df.loc[guides_include, adata.var._rc.unique()]
#adata.var.pop("_rc")
else:
new_uns[k] = df.loc[guides_include, adata.var.rep.unique()]
if not isinstance(df, pd.DataFrame):
if k == "sample_covariates":
new_uns[k] = df
continue
if "guide" in df.columns:
if "allele" in df.columns:
@@ -892,7 +906,7 @@ def concat(screens: Collection[ReporterScreen], *args, axis=1, **kwargs):

if axis == 0:
for k in keys:
if k in ["target_base_change", "tiling"]:
if k in ["target_base_change", "tiling", "sample_covariates"]:
adata.uns[k] = screens[0].uns[k]
continue
elif "edit" not in k and "allele" not in k:
@@ -902,7 +916,7 @@ def concat(screens: Collection[ReporterScreen], *args, axis=1, **kwargs):
if axis == 1:
# If combining multiple samples, edit/allele tables should be merged.
for k in keys:
if k in ["target_base_change", "tiling"]:
if k in ["target_base_change", "tiling", "sample_covariates"]:
adata.uns[k] = screens[0].uns[k]
continue
elif "edit" not in k and "allele" not in k:
2 changes: 1 addition & 1 deletion bean/mapping/GuideEditCounter.py
@@ -490,7 +490,7 @@ def _count_reporter_edits(
)

def _get_guide_counts_bcmatch_semimatch(
self, bcmatch_layer="X_bcmatch", semimatch_layer="X"
self, bcmatch_layer="X_bcmatch", semimatch_layer="X_semimatch"
):
self.screen.layers[semimatch_layer] = np.zeros_like((self.screen.X))
R1_iter, R2_iter = self._get_fastq_iterators(
74 changes: 53 additions & 21 deletions bean/model/model.py
@@ -45,26 +45,51 @@ def NormalModel(
sd = sd_alleles
sd = torch.repeat_interleave(sd, data.target_lengths, dim=0)
assert sd.shape == (data.n_guides, 1)

if hasattr(data, "sample_covariates"):
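        # One mean shift (mu_cov) is sampled per sample covariate; it is added below to the
        # variant effect mean for the replicates carrying that covariate.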
with pyro.plate("cov_place", data.n_sample_covariates):
mu_cov = pyro.sample("mu_cov", dist.Normal(0, 1))
assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape
with replicate_plate:
with bin_plate as b:
uq = data.upper_bounds[b]
lq = data.lower_bounds[b]
assert uq.shape == lq.shape == (data.n_condits,)
# with guide_plate, poutine.mask(mask=(data.allele_counts.sum(axis=-1) == 0)):
with guide_plate:
mu = (
mu.unsqueeze(0)
.unsqueeze(0)
.expand((data.n_reps, data.n_condits, -1, -1))
)
if hasattr(data, "sample_covariates"):
mu = mu + (data.rep_by_cov * mu_cov)[:, 0].unsqueeze(-1).unsqueeze(
-1
).unsqueeze(-1).expand((-1, data.n_condits, data.n_guides, 1))
sd = torch.sqrt(
(
sd.unsqueeze(0)
.unsqueeze(0)
.expand((data.n_reps, data.n_condits, -1, -1))
)
)
alleles_p_bin = get_std_normal_prob(
uq.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)),
lq.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)),
mu.unsqueeze(0).expand((data.n_condits, -1, -1)),
sd.unsqueeze(0).expand((data.n_condits, -1, -1)),
uq.unsqueeze(0)
.unsqueeze(-1)
.unsqueeze(-1)
.expand((data.n_reps, -1, data.n_guides, 1)),
lq.unsqueeze(0)
.unsqueeze(-1)
.unsqueeze(-1)
.expand((data.n_reps, -1, data.n_guides, 1)),
mu,
sd,
)
assert alleles_p_bin.shape == (data.n_condits, data.n_guides, 1)

expected_allele_p = alleles_p_bin.unsqueeze(0).expand(
data.n_reps, -1, -1, -1
)
expected_guide_p = expected_allele_p.sum(axis=-1)
assert alleles_p_bin.shape == (
data.n_reps,
data.n_condits,
data.n_guides,
1,
)
expected_guide_p = alleles_p_bin.sum(axis=-1)
assert expected_guide_p.shape == (
data.n_reps,
data.n_condits,
@@ -91,7 +116,6 @@ def NormalModel(
obs=data.X_masked.permute(0, 2, 1),
)
if use_bcmatch:
print(f"Use_bcmatch:{use_bcmatch}")
a_bcmatch = get_alpha(
expected_guide_p,
data.size_factor_bcmatch,
@@ -159,14 +183,10 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True):
with pyro.plate("guide_plate3", data.n_guides, dim=-1):
a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0)

assert (
data.X.shape
== data.X_bcmatch.shape
== (
data.n_reps,
data.n_condits,
data.n_guides,
)
assert data.X.shape == (
data.n_reps,
data.n_condits,
data.n_guides,
)
with poutine.mask(
mask=torch.logical_and(
@@ -491,6 +511,18 @@ def NormalGuide(data):
constraint=constraints.positive,
)
pyro.sample("sd_alleles", dist.LogNormal(sd_loc, sd_scale))
if hasattr(data, "sample_covariates"):
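        # Variational parameters for the per-covariate mean shift mu_cov, mirroring the model.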
with pyro.plate("cov_place", data.n_sample_covariates):
mu_cov_loc = pyro.param(
"mu_cov_loc", torch.zeros((data.n_sample_covariates,))
)
mu_cov_scale = pyro.param(
"mu_cov_scale",
torch.ones((data.n_sample_covariates,)),
constraint=constraints.positive,
)
mu_cov = pyro.sample("mu_cov", dist.Normal(mu_cov_loc, mu_cov_scale))
assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape


def MixtureNormalGuide(
87 changes: 71 additions & 16 deletions bean/model/readwrite.py
@@ -1,4 +1,4 @@
from typing import Union, Sequence, Optional
from typing import Union, Sequence, Optional, List
import numpy as np
import pandas as pd
from statistics import NormalDist
@@ -57,42 +57,81 @@ def write_result_table(
adjust_confidence_negatives: np.ndarray = None,
guide_index: Optional[Sequence[str]] = None,
guide_acc: Optional[Sequence] = None,
sd_is_fitted: bool = True,
sample_covariates: List[str] = None,
return_result: bool = False,
) -> Union[pd.DataFrame, None]:
"""Combine target information and scores to write result table to a csv file or return it."""
if param_hist_dict["params"]["mu_loc"].dim() == 2:
mu = param_hist_dict["params"]["mu_loc"].detach()[:, 0].cpu().numpy()
mu_sd = param_hist_dict["params"]["mu_scale"].detach()[:, 0].cpu().numpy()
sd = param_hist_dict["params"]["sd_loc"].detach().exp()[:, 0].cpu().numpy()
if sd_is_fitted:
sd = param_hist_dict["params"]["sd_loc"].detach().exp()[:, 0].cpu().numpy()
elif param_hist_dict["params"]["mu_loc"].dim() == 1:
mu = param_hist_dict["params"]["mu_loc"].detach().cpu().numpy()
mu_sd = param_hist_dict["params"]["mu_scale"].detach().cpu().numpy()
sd = param_hist_dict["params"]["sd_loc"].detach().exp().cpu().numpy()
if sd_is_fitted:
sd = param_hist_dict["params"]["sd_loc"].detach().exp().cpu().numpy()
else:
raise ValueError(
f'`mu_loc` has invalid shape of {param_hist_dict["params"]["mu_loc"].shape}'
)
fit_df = pd.DataFrame(
{
"mu": mu,
"mu_sd": mu_sd,
"mu_z": mu / mu_sd,
"sd": sd,
}
)
param_dict = {
"mu": mu,
"mu_sd": mu_sd,
"mu_z": mu / mu_sd,
}
if sd_is_fitted:
param_dict["sd"] = sd
if sample_covariates is not None:
assert (
"mu_cov_loc" in param_hist_dict["params"]
and "mu_cov_scale" in param_hist_dict["params"]
), param_hist_dict["params"].keys()
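        # Per-covariate effect = baseline mu plus the fitted covariate shift; uncertainties are
        # combined in quadrature: mu_sd_cov = sqrt(mu_sd**2 + mu_cov_scale**2).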
for i, sample_cov in enumerate(sample_covariates):
param_dict[f"mu_{sample_cov}"] = (
mu + param_hist_dict["params"]["mu_cov_loc"].detach().cpu().numpy()[i]
)
param_dict[f"mu_sd_{sample_cov}"] = np.sqrt(
mu_sd**2
+ param_hist_dict["params"]["mu_cov_scale"].detach().cpu().numpy()[i]
** 2
)
param_dict[f"mu_z_{sample_cov}"] = (
param_dict[f"mu_{sample_cov}"] / param_dict[f"mu_sd_{sample_cov}"]
)

fit_df = pd.DataFrame(param_dict)
fit_df["novl"] = get_novl(fit_df, "mu", "mu_sd")
if "negctrl" in param_hist_dict.keys():
print("Normalizing with common negative control distribution")
mu0 = param_hist_dict["negctrl"]["params"]["mu_loc"].detach().cpu().numpy()
sd0 = (
param_hist_dict["negctrl"]["params"]["sd_loc"].detach().exp().cpu().numpy()
)
print(f"Fitted mu0={mu0}, sd0={sd0}.")
if sd_is_fitted:
sd0 = (
param_hist_dict["negctrl"]["params"]["sd_loc"]
.detach()
.exp()
.cpu()
.numpy()
)
print(f"Fitted mu0={mu0}" + (f", sd0={sd0}." if sd_is_fitted else ""))
fit_df["mu_scaled"] = (mu - mu0) / sd0
fit_df["mu_sd_scaled"] = mu_sd / sd0
fit_df["mu_z_scaled"] = fit_df.mu_scaled / fit_df.mu_sd_scaled
fit_df["sd_scaled"] = sd / sd0
if sd_is_fitted:
fit_df["sd_scaled"] = sd / sd0
fit_df["novl_scaled"] = get_novl(fit_df, "mu_scaled", "mu_sd_scaled")
if sample_covariates is not None:
for i, sample_cov in enumerate(sample_covariates):
fit_df[f"mu_{sample_cov}_scaled"] = (
fit_df[f"mu_{sample_cov}"] - mu0
) / sd0
fit_df[f"mu_sd_{sample_cov}_scaled"] = (
fit_df[f"mu_sd_{sample_cov}"] / sd0
)
fit_df[f"mu_z_{sample_cov}_scaled"] = (
fit_df[f"mu_{sample_cov}_scaled"] / fit_df["mu_sd_scaled"]
)

fit_df = pd.concat(
[target_info_df.reset_index(), fit_df.reset_index(drop=True)], axis=1
@@ -123,6 +162,22 @@ def write_result_table(
else "mu_sd",
)
fit_df = add_credible_interval(fit_df, "mu_adj", "mu_sd_adj")
if sample_covariates is not None:
for i, sample_cov in enumerate(sample_covariates):
fit_df = adjust_normal_params_by_control(
fit_df,
std,
suffix=f"_{sample_cov}_adj",
mu_adjusted_col=f"mu_{sample_cov}_scaled"
if "negctrl" in param_hist_dict.keys()
else f"mu_{sample_cov}",
mu_sd_adjusted_col=f"mu_sd_{sample_cov}_scaled"
if "negctrl" in param_hist_dict.keys()
else f"mu_sd_{sample_cov}",
)
fit_df = add_credible_interval(
fit_df, f"mu_{sample_cov}_adj", f"mu_sd_{sample_cov}_adj"
)

if write_fitted_eff or guide_acc is not None:
if "alpha_pi" not in param_hist_dict["params"].keys():