diff --git a/.github/workflows/pypi_release.yml b/.github/workflows/pypi_release.yml deleted file mode 100644 index 8ef81a4..0000000 --- a/.github/workflows/pypi_release.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: Publish to PyPI - -on: - release: - types: [published] - -jobs: - deploy: - name: Build and publish to PyPI - runs-on: ubuntu-latest - environment: - name: pypi - url: https://pypi.org/p/crispr-bean - permissions: - id-token: write - steps: - - uses: actions/checkout@main - - name: Set up Python 3.x - uses: actions/setup-python@v3 - with: - python-version: '3.x' - - name: Install pypa/setuptools - run: >- - python -m pip install --upgrade pip - pip install wheel numpy setuptools Cython auditwheel twine - - name: Extract tag name - id: tag - run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3) - - name: Build - run: | - python setup.py sdist bdist_wheel - auditwheel repair dist/crispr_bean-${{ steps.tag.outputs.TAG_NAME }}-cp*-cp*-linux_x86_64.whl --plat manylinux_2_24_x86_64 - mv wheelhouse/* dist - rm dist/*-cp*-cp*-linux_x86_64.whl - - name: Test build - run: | - pip install dist/*.whl - pytest --sparse-ordering - - name: Publish - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - run: - twine upload dist/* \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index bf8155c..9b11cdb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,9 @@ # Changelog -## 1.2.5 -* Allow `bean run .. tiling` for untranslated `--allele-df-key`. - +## 1.2.8 +* Change .pyx files to be compatible with more recent numpy versions +## 1.2.7 +* **CRITICAL** Fix sample ordering & masking issue for survival screens ## 1.2.6 * Fix overflow in `bean run survival` and autograde error related to inplace assignment for `bean run survival tiling`. - -## 1.2.7 -* **CRITICAL** Fix sample ordering & masking issue for survival screens \ No newline at end of file +## 1.2.5 +* Allow `bean run .. tiling` for untranslated `--allele-df-key`. \ No newline at end of file diff --git a/README.md b/README.md index 6f6760d..d189358 100755 --- a/README.md +++ b/README.md @@ -22,9 +22,10 @@ 2. [`profile`](https://pinellolab.github.io/crispr-bean/profile.html): Profile editing preferences of your editor. 3. [`qc`](https://pinellolab.github.io/crispr-bean/qc.html): Quality control report and filtering out / masking of aberrant sample and guides 4. [`filter`](https://pinellolab.github.io/crispr-bean/filter.html): Filter reporter alleles; essential for `tiling` mode that allows for all alleles generated from gRNA. -5. [`run`](https://pinellolab.github.io/crispr-bean/run.html): Quantify targeted variants' effect sizes from screen data. +5. [`run`](https://pinellolab.github.io/crispr-bean/run.html): Quantify targeted variants' effect sizes from screen data. **See more about the model in the link**. * Screen data is saved as [`ReporterScreen` object](https://pinellolab.github.io/crispr-bean/reporterscreen.html) in the pipeline. BEAN stores mapped gRNA and allele counts in `ReporterScreen` object which is compatible with [AnnData](https://anndata.readthedocs.io/en/latest/index.html). + ## Installation First install [PyTorch](https://pytorch.org/get-started/). 
Then download from PyPI:
@@ -50,6 +51,7 @@ See the [documentation](https://pinellolab.github.io/crispr-bean/) for tutorials
 | GWAS variant library | Survival / Proliferation | Yes/No | [GWAS variant screen](https://pinellolab.github.io/crispr-bean/tutorial_prolif_gwas.html)
 | Coding sequence tiling library | Survival / Proliferation | Yes/No | [Coding sequence tiling screen](https://pinellolab.github.io/crispr-bean/tutorial_prolif_cds.html)
 | Perturbation library without reporter | FACS sorting | No | [No reporter screen](https://pinellolab.github.io/crispr-bean/tutorial_no_edit.html)
+| Integration of disjoint libraries | Any | Any | [Feeding custom prior](https://pinellolab.github.io/crispr-bean/tutorial_custom_prior.html)
 
 Also see notebook that visualizes screen analysis result [here](https://github.com/pinellolab/crispr-bean/blob/main/docs/visualize_var.ipynb).
diff --git a/bean/cli/build_prior.py b/bean/cli/build_prior.py
new file mode 100644
index 0000000..b9f3475
--- /dev/null
+++ b/bean/cli/build_prior.py
@@ -0,0 +1,72 @@
+import pickle as pkl
+import numpy as np
+import torch
+from bean.model.run import _get_guide_target_info
+from bean.model.parser import parse_args
+from bean.cli.run import main as get_screendata
+from bean.preprocessing.data_class import SortingScreenData
+
+
+def generate_prior_data_for_disjoint_library_pair(
+    command1: str, command2: str, output1_path: str, prior_params_path: str
+):
+    """Generate prior for two batches with disjoint guides but shared variants."""
+    with open(output1_path, "rb") as f:
+        data = pkl.load(f)
+    ndata = data["data"]
+    parser = parse_args()
+    command1 = command1.split("bean run ")[-1]
+    command2 = command2.split("bean run ")[-1]
+    args = parser.parse_args(command1.split(" "))
+    args2 = parser.parse_args(command2.split(" "))
+    ndata2 = get_screendata(args2, return_data=True)
+    target_df = _get_guide_target_info(
+        ndata.screen, args, cols_include=[args.negctrl_col]
+    )
+    target_df2 = _get_guide_target_info(
+        ndata2.screen, args2, cols_include=[args2.negctrl_col]
+    )
+    batch1_idx = np.where(
+        target_df.index.map(lambda s: s in target_df2.index.tolist())
+    )[0]
+    batch2_idx = []
+    for i in batch1_idx:
+        batch2_idx.append(
+            np.where(target_df.index.tolist()[i] == target_df2.index)[0].item()
+        )
+    batch2_idx = np.array(batch2_idx)
+    if isinstance(ndata, SortingScreenData):
+        mu_loc = torch.zeros((ndata2.n_targets, 1))
+        mu_loc[batch2_idx, :] = data["params"]["mu_loc"][batch1_idx, :]
+        mu_scale = torch.ones((ndata2.n_targets, 1))
+        mu_scale[batch2_idx, :] = data["params"]["mu_scale"][batch1_idx, :]
+        sd_loc = torch.zeros((ndata2.n_targets, 1))
+        sd_loc[batch2_idx, :] = data["params"]["sd_loc"][batch1_idx, :]
+        sd_scale = torch.ones((ndata2.n_targets, 1)) * 0.01
+        sd_scale[batch2_idx, :] = data["params"]["sd_scale"][batch1_idx, :]
+        prior_params = {
+            "mu_loc": mu_loc,
+            "mu_scale": mu_scale,
+            "sd_loc": sd_loc,
+            "sd_scale": sd_scale,
+        }
+    else:
+        mu_loc = torch.zeros((ndata2.n_targets, 1))
+        mu_loc[batch2_idx, :] = data["params"]["mu_loc"][batch1_idx, :]
+        mu_scale = torch.ones((ndata2.n_targets, 1))
+        mu_scale[batch2_idx, :] = data["params"]["mu_scale"][batch1_idx, :]
+        prior_params = {
+            "mu_loc": mu_loc,
+            "mu_scale": mu_scale,
+        }
+    with open(prior_params_path, "wb") as f:
+        pkl.dump(prior_params, f)
+    print(
+        f"Successfully generated prior parameters at {prior_params_path}.
To use this parameter, run:\nbean run {command2+' --prior-params '+prior_params_path}" + ) + + +def main(args): + generate_prior_data_for_disjoint_library_pair( + args.command1, args.command2, args.raw_run_output1, args.output_path + ) diff --git a/bean/cli/execute.py b/bean/cli/execute.py index 5917bd9..cf23daf 100755 --- a/bean/cli/execute.py +++ b/bean/cli/execute.py @@ -7,6 +7,7 @@ from bean.model.parser import parse_args as get_run_parser from bean.framework.parser import get_input_parser as get_create_screen_parser from bean.annotate.utils import get_splice_parser as get_splice_site_parser +from bean.model.parser_prior import parse_args as get_prior_parser from bean.cli.count import main as count from bean.cli.count_samples import main as count_samples from bean.cli.profile import main as profile @@ -15,6 +16,15 @@ from bean.cli.run import main as run from bean.cli.create_screen import main as create_screen from bean.cli.get_splice_sites import main as get_splice_sites +from bean.cli.build_prior import main as build_prior + +import warnings + +warnings.filterwarnings( + action="ignore", + category=FutureWarning, + message=r".*The default of observed=False is deprecated and will be changed to True in a future version of pandas.*", +) def get_parser(): @@ -40,6 +50,10 @@ def get_parser(): "get-splice-sites", help="get splice sites" ) splice_site_parser = get_splice_site_parser(splice_site_parser) + prior_parser = subparsers.add_parser( + "build-prior", help="obtain prior_params.pkl for batched runs" + ) + prior_parser = get_prior_parser(prior_parser) return parser @@ -65,5 +79,7 @@ def main() -> None: create_screen(args) elif args.subcommand == "get-splice-sites": get_splice_sites(args) + elif args.subcommand == "build-prior": + build_prior(args) else: parser.print_help() diff --git a/bean/cli/run.py b/bean/cli/run.py index d39a1dd..0195965 100755 --- a/bean/cli/run.py +++ b/bean/cli/run.py @@ -31,6 +31,7 @@ check_args, identify_model_guide, identify_negctrl_model_guide, + _check_prior_params, ) logging.basicConfig( @@ -60,7 +61,7 @@ ) -def main(args): +def main(args, return_data=False): print( r""" _ _ @@ -114,7 +115,8 @@ def main(args): replicate_col=args.replicate_col, use_bcmatch=(not args.ignore_bcmatch), ) - + if return_data: + return ndata # Build variant dataframe adj_negctrl_idx = None if args.library_design == "variant": @@ -183,6 +185,11 @@ def main(args): ) guide_info_df = ndata.screen.guides + # Add user-defined prior. + if args.prior_params is not None: + prior_params = _check_prior_params(args.prior_params, ndata) + model = partial(model, prior_params=prior_params) + # Run the inference steps info(f"Running inference for {model_label}...") if args.load_existing: @@ -211,6 +218,8 @@ def main(args): ) else: param_history_dict_negctrl = None + save_dict["data"] = ndata + # Save results outfile_path = ( f"{prefix}/bean_element[sgRNA]_result.{model_label}{args.result_suffix}.csv" @@ -218,8 +227,11 @@ def main(args): info(f"Done running inference. 
Writing result at {outfile_path}...") if not os.path.exists(prefix): os.makedirs(prefix) - with open(f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb") as handle: - pkl.dump(save_dict, handle) + if args.save_raw: + with open( + f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb" + ) as handle: + pkl.dump(save_dict, handle) write_result_table( target_info_df, param_history_dict, diff --git a/bean/mapping/utils.py b/bean/mapping/utils.py index 6259c2b..8f220ee 100755 --- a/bean/mapping/utils.py +++ b/bean/mapping/utils.py @@ -321,6 +321,7 @@ def _check_arguments(args, info_logger, warn_logger, error_logger): sgRNA_info_tbl = pd.read_csv(args.sgRNA_filename) def _check_sgrna_info_table(args, sgRNA_info_tbl): + # Check column names if args.offset: if "offset" not in sgRNA_info_tbl.columns: raise InputFileError( @@ -345,6 +346,10 @@ def _check_sgrna_info_table(args, sgRNA_info_tbl): raise InputFileError( f"Offset option is set but the input file doesn't contain the `reporter` column: Provided {sgRNA_info_tbl.columns}" ) + if sgRNA_info_tbl["name"].duplicated().any(): + raise InputFileError( + f"Duplicate guide names: {sgRNA_info_tbl.loc[sgRNA_info_tbl['name'].duplicated(),:].index}. Please provide unique IDs for each guide." + ) _check_sgrna_info_table(args, sgRNA_info_tbl) diff --git a/bean/model/model.py b/bean/model/model.py index 779cfc1..cbf4df2 100755 --- a/bean/model/model.py +++ b/bean/model/model.py @@ -1,3 +1,4 @@ +from typing import Optional import torch import pyro from pyro import poutine @@ -16,33 +17,57 @@ def NormalModel( - data: VariantSortingScreenData, mask_thres: int = 10, use_bcmatch: bool = True + data: VariantSortingScreenData, + mask_thres: int = 10, + use_bcmatch: bool = True, + sd_scale: float = 0.01, + prior_params: Optional[dict] = None, ): """ Fit only on guide counts - Args - -- - data: input data + + Args: + data: input data + mask_thres: threshold for masking guide counts for stability. Defaults to 10. + use_bcmatch: whether to use barcode-matched counts. Defaults to True. + sd_scale: scale for prior standard deviation. Defaults to 0.01 for improved identifiability. + prior_params: prior parameters. If provided, specified prior parameters will be used. 
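+            Supported keys, as validated by `_check_prior_params` in `bean/model/run.py`: `mu_loc`/`mu_scale` (scalar or per-target tensor) and `sd_loc`/`sd_scale` (per-target tensors of shape `(n_targets, 1)`; `(n_targets,)` tensors are reshaped automatically).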
""" replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3) replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2) bin_plate = pyro.plate("bin_plate", data.n_condits, dim=-2) guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) + sd_loc = torch.zeros((data.n_targets, 1)) + sd_scale = torch.ones((data.n_targets, 1)) * sd_scale + mu_dist = dist.Laplace(0, 1) + if prior_params is not None: + if "sd_loc" in prior_params: + sd_loc = prior_params["sd_loc"] + if "sd_scale" in prior_params: + sd_scale = prior_params["sd_scale"] + if "mu_loc" in prior_params or "mu_scale" in prior_params: + mu_loc = 0.0 + mu_scale = 1.0 + if "mu_loc" in prior_params: + mu_loc = prior_params["mu_loc"] + if "mu_scale" in prior_params: + mu_scale = prior_params["mu_scale"] + mu_dist = dist.Normal(mu_loc, mu_scale) + # Set the prior for phenotype means with pyro.plate("guide_plate0", 1): with pyro.plate("guide_plate1", data.n_targets): - mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1)) - sd_alleles = pyro.sample( - "sd_alleles", - dist.LogNormal( - torch.zeros((data.n_targets, 1)), torch.ones(data.n_targets, 1) - ), + mu_targets = pyro.sample("mu_targets", mu_dist) + sd_targets = pyro.sample( + "sd_targets", + dist.LogNormal(sd_loc, sd_scale), ) - mu_center = mu_alleles + + mu_center = mu_targets mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) assert mu.shape == (data.n_guides, 1) - sd = sd_alleles + sd = sd_targets sd = torch.repeat_interleave(sd, data.target_lengths, dim=0) assert sd.shape == (data.n_guides, 1) if hasattr(data, "sample_covariates"): @@ -150,10 +175,10 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True): guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) # Set the prior for phenotype means - mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1)) - sd_alleles = pyro.sample("sd_alleles", dist.LogNormal(0, 1)) - mu = mu_alleles.repeat(data.n_guides).unsqueeze(-1) - sd = sd_alleles.repeat(data.n_guides).unsqueeze(-1) + mu_targets = pyro.sample("mu_targets", dist.Laplace(0, 1)) + sd_targets = pyro.sample("sd_targets", dist.LogNormal(0, 1)) + mu = mu_targets.repeat(data.n_guides).unsqueeze(-1) + sd = sd_targets.repeat(data.n_guides).unsqueeze(-1) with replicate_plate: with bin_plate as b: @@ -246,19 +271,19 @@ def MixtureNormalConstPiModel( # Set the prior for phenotype means with pyro.plate("guide_plate0", 1): with pyro.plate("guide_plate1", data.n_targets): - mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1)) - sd_alleles = pyro.sample( - "sd_alleles", + mu_targets = pyro.sample("mu_targets", dist.Laplace(0, 1)) + sd_targets = pyro.sample( + "sd_targets", dist.LogNormal( torch.zeros((data.n_targets, 1)), torch.ones(data.n_targets, 1) * sd_scale, ), ) - mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_alleles], axis=-1) + mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1) mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) assert mu.shape == (data.n_guides, 2) - sd = torch.cat([torch.ones((data.n_targets, 1)), sd_alleles], axis=-1) + sd = torch.cat([torch.ones((data.n_targets, 1)), sd_targets], axis=-1) sd = torch.repeat_interleave(sd, data.target_lengths, dim=0) assert sd.shape == (data.n_guides, 2) # The pi should be Dirichlet distributed instead of independent betas @@ -357,10 +382,19 @@ def MixtureNormalModel( sd_scale: float = 0.01, scale_by_accessibility: bool = False, fit_noise: bool = False, + prior_params: Optional[dict] = None, ): """ + Using the 
reporter outcome, phenotype of cells with a guide will be modeled as mixture of two normal distributions of edited and unedited cells. + Args: - scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility + data: Input data of type VariantSortingReporterScreenData. + alpha_prior: Prior parameter for controlling the concentration of the Dirichlet process. Defaults to 1. + use_bcmatch: Flag indicating whether to use barcode-matched counts. Defaults to True. + sd_scale: Scale for the prior standard deviation. Defaults to 0.01. + scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility. + fit_noise: Valid only when scale_by_accessibility is True. If True, parametrically fit noise of endo ~ reporter + noise. + prior_params: Optional dictionary of prior parameters. If provided, specified prior parameters will be used. """ torch.autograd.set_detect_anomaly(True) replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3) @@ -368,22 +402,36 @@ def MixtureNormalModel( bin_plate = pyro.plate("bin_plate", data.n_condits, dim=-2) guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) + sd_loc = torch.zeros((data.n_targets, 1)) + sd_scale = torch.ones((data.n_targets, 1)) * sd_scale + mu_dist = dist.Laplace(0, 1) + if prior_params is not None: + if "sd_loc" in prior_params: + sd_loc = prior_params["sd_loc"] + if "sd_scale" in prior_params: + sd_scale = prior_params["sd_scale"] + if "mu_loc" in prior_params or "mu_scale" in prior_params: + mu_loc = 0.0 + mu_scale = 1.0 + if "mu_loc" in prior_params: + mu_loc = prior_params["mu_loc"] + if "mu_scale" in prior_params: + mu_scale = prior_params["mu_scale"] + mu_dist = dist.Normal(mu_loc, mu_scale) # Set the prior for phenotype means with pyro.plate("guide_plate0", 1): with pyro.plate("guide_plate1", data.n_targets): - mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1)) - sd_alleles = pyro.sample( - "sd_alleles", - dist.LogNormal( - torch.zeros((data.n_targets, 1)), - torch.ones(data.n_targets, 1) * sd_scale, - ), + mu_targets = pyro.sample("mu_targets", mu_dist) + sd_targets = pyro.sample( + "sd_targets", + dist.LogNormal(sd_loc, sd_scale), ) - mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_alleles], axis=-1) + + mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1) mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) assert mu.shape == (data.n_guides, 2) - sd = torch.cat([torch.ones((data.n_targets, 1)), sd_alleles], axis=-1) + sd = torch.cat([torch.ones((data.n_targets, 1)), sd_targets], axis=-1) sd = torch.repeat_interleave(sd, data.target_lengths, dim=0) assert sd.shape == (data.n_guides, 2) # The pi should be Dirichlet distributed instead of independent betas @@ -499,174 +547,82 @@ def MixtureNormalModel( ) -def NormalGuide(data): - with pyro.plate("guide_plate0", 1): - with pyro.plate("guide_plate1", data.n_targets): - mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1))) - mu_scale = pyro.param( - "mu_scale", - torch.ones((data.n_targets, 1)), - constraint=constraints.positive, - ) - pyro.sample("mu_alleles", dist.Normal(mu_loc, mu_scale)) - sd_loc = pyro.param("sd_loc", torch.zeros((data.n_targets, 1))) - sd_scale = pyro.param( - "sd_scale", - torch.ones((data.n_targets, 1)), - constraint=constraints.positive, - ) - pyro.sample("sd_alleles", dist.LogNormal(sd_loc, sd_scale)) - if hasattr(data, "sample_covariates"): - with pyro.plate("cov_place", data.n_sample_covariates): - mu_cov_loc = pyro.param( - 
"mu_cov_loc", torch.zeros((data.n_sample_covariates,)) - ) - mu_cov_scale = pyro.param( - "mu_cov_scale", - torch.ones((data.n_sample_covariates,)), - constraint=constraints.positive, - ) - mu_cov = pyro.sample("mu_cov", dist.Normal(mu_cov_loc, mu_cov_scale)) - assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape - - -def MixtureNormalGuide( - data, - alpha_prior: float = 1, - use_bcmatch: bool = True, - scale_by_accessibility: bool = False, - fit_noise: bool = False, -): - """ - Guide for MixtureNormal model - """ - - replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3) - guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) - - # Set the prior for phenotype means - mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1))) - mu_scale = pyro.param( - "mu_scale", torch.ones((data.n_targets, 1)), constraint=constraints.positive - ) - sd_loc = pyro.param("sd_loc", torch.zeros((data.n_targets, 1))) - sd_scale = pyro.param( - "sd_scale", torch.ones((data.n_targets, 1)), constraint=constraints.positive - ) - with pyro.plate("guide_plate0", 1): - with pyro.plate("guide_plate1", data.n_targets): - mu_alleles = pyro.sample("mu_alleles", dist.Normal(mu_loc, mu_scale)) - sd_alleles = pyro.sample("sd_alleles", dist.LogNormal(sd_loc, sd_scale)) - mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_alleles], axis=-1) - mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) - assert mu.shape == (data.n_guides, 2) - - sd = torch.cat([torch.ones((data.n_targets, 1)), sd_alleles], axis=-1) - sd = torch.repeat_interleave(sd, data.target_lengths, dim=0) - assert sd.shape == (data.n_guides, 2) - # The pi should be Dirichlet distributed instead of independent betas - alpha_pi = pyro.param( - "alpha_pi", - torch.ones( - ( - data.n_guides, - 2, - ) - ) - * alpha_prior, - constraint=constraints.positive, - ) - assert alpha_pi.shape == ( - data.n_guides, - 2, - ), alpha_pi.shape - pi_a_scaled = alpha_pi / alpha_pi.sum(axis=-1)[:, None] * data.pi_a0[:, None] - - with replicate_plate: - with guide_plate: - pi = pyro.sample( - "pi", - dist.Dirichlet( - pi_a_scaled.unsqueeze(0) - .unsqueeze(0) - .expand(data.n_reps, 1, -1, -1) - .clamp(1e-5) - ), - ) - assert pi.shape == ( - data.n_reps, - 1, - data.n_guides, - 2, - ), pi.shape - if scale_by_accessibility: - # Endogenous target site editing rate may be different - pi = scale_pi_by_accessibility( - pi, data.guide_accessibility, fit_noise=fit_noise - ) - - -def ControlNormalGuide(data, mask_thres=10, use_bcmatch=True): - """ - Fit shared mean - """ - # Set the prior for phenotype means - mu_loc = pyro.param("mu_loc", torch.tensor(0.0)) - mu_scale = pyro.param( - "mu_scale", torch.tensor(1.0), constraint=constraints.positive - ) - sd_loc = pyro.param("sd_loc", torch.tensor(0.0)) - sd_scale = pyro.param( - "sd_scale", torch.tensor(1.0), constraint=constraints.positive - ) - pyro.sample("mu_alleles", dist.Normal(mu_loc, mu_scale)) - pyro.sample("sd_alleles", dist.LogNormal(sd_loc, sd_scale)) - - def MultiMixtureNormalModel( data: TilingSortingReporterScreenData, alpha_prior=1, use_bcmatch=True, sd_scale=0.01, - norm_pi=False, scale_by_accessibility=False, - epsilon=1e-5, fit_noise: bool = False, + prior_params: Optional[dict] = None, + epsilon=1e-5, ): - """Tiling version of MixtureNormalModel""" + """ + Using the reporter outcome, phenotype of cells with a guide will be modeled as mixture of normal distributions of all major alleles (including WT) produced by the guide. 
+
+    Args:
+        data: Input data of type TilingSortingReporterScreenData.
+        alpha_prior: Prior parameter for controlling the concentration of the Dirichlet process. Defaults to 1.
+        use_bcmatch: Flag indicating whether to use barcode-matched counts. Defaults to True.
+        sd_scale: Scale for the prior standard deviation. Defaults to 0.01.
+        scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility.
+        fit_noise: Valid only when scale_by_accessibility is True. If True, parametrically fit noise of endo ~ reporter + noise.
+        prior_params: Optional dictionary of prior parameters. If provided, specified prior parameters will be used.
+        epsilon: Small value to avoid division by zero, assigned as Dirichlet parameters for non-existing alleles.
+    """
     replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3)
     replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2)
     bin_plate = pyro.plate("bin_plate", data.n_condits, dim=-2)
     guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1)
 
+    sd_loc = torch.zeros((data.n_edits,))
+    sd_scale = (
+        torch.ones(
+            data.n_edits,
+        )
+        * sd_scale
+    )
+    mu_dist = dist.Laplace(0, 1)
+    if prior_params is not None:
+        if "sd_loc" in prior_params:
+            sd_loc = prior_params["sd_loc"]
+        if "sd_scale" in prior_params:
+            sd_scale = prior_params["sd_scale"]
+        if "mu_loc" in prior_params or "mu_scale" in prior_params:
+            mu_loc = 0.0
+            mu_scale = 1.0
+            if "mu_loc" in prior_params:
+                mu_loc = prior_params["mu_loc"]
+            if "mu_scale" in prior_params:
+                mu_scale = prior_params["mu_scale"]
+            mu_dist = dist.Normal(mu_loc, mu_scale)
+
     # Set the prior for phenotype means
     with pyro.plate("guide_plate1", data.n_edits):
-        mu_edits = pyro.sample("mu_alleles", dist.Laplace(0, 1))
+        mu_edits = pyro.sample("mu_targets", mu_dist)
         sd_edits = pyro.sample(
-            "sd_alleles",
+            "sd_targets",
             dist.LogNormal(
-                torch.zeros((data.n_edits,)),
-                torch.ones(
-                    data.n_edits,
-                )
-                * sd_scale,
+                sd_loc,
+                sd_scale,
             ),
         )
+
     assert mu_edits.shape == sd_edits.shape == (data.n_edits,)
     assert data.allele_to_edit.shape == (
         data.n_guides,
         data.n_max_alleles - 1,
         data.n_edits,
     )
-    mu_alleles = torch.matmul(data.allele_to_edit, mu_edits)
-    assert mu_alleles.shape == (data.n_guides, data.n_max_alleles - 1)
-    sd_alleles = torch.linalg.norm(
+    mu_targets = torch.matmul(data.allele_to_edit, mu_edits)
+    assert mu_targets.shape == (data.n_guides, data.n_max_alleles - 1)
+    sd_targets = torch.linalg.norm(
         data.allele_to_edit * sd_edits[None, None, :], dim=-1
     )  # Frobenius 2-norm
-    mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_alleles], axis=-1)
-    sd = torch.cat([torch.ones((data.n_guides, 1)), sd_alleles], axis=-1)
+    mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_targets], axis=-1)
+    sd = torch.cat([torch.ones((data.n_guides, 1)), sd_targets], axis=-1)
     assert mu.shape == sd.shape == (data.n_guides, data.n_max_alleles), (
         mu.shape,
         sd.shape,
@@ -795,6 +751,130 @@
     )
 
 
+def NormalGuide(data):
+    with pyro.plate("guide_plate0", 1):
+        with pyro.plate("guide_plate1", data.n_targets):
+            mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1)))
+            mu_scale = pyro.param(
+                "mu_scale",
+                torch.ones((data.n_targets, 1)),
+                constraint=constraints.positive,
+            )
+            pyro.sample("mu_targets", dist.Normal(mu_loc, mu_scale))
+            sd_loc = pyro.param("sd_loc", torch.zeros((data.n_targets, 1)))
+            sd_scale = pyro.param(
+                "sd_scale",
+                torch.ones((data.n_targets, 1)),
+                constraint=constraints.positive,
+            )
+            pyro.sample("sd_targets", dist.LogNormal(sd_loc, sd_scale))
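+    # These variational parameters ("mu_loc", "mu_scale", "sd_loc", "sd_scale")
+    # are what `bean build-prior` harvests from a run saved with --save-raw to
+    # seed the prior of a disjoint second library.
+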
if hasattr(data, "sample_covariates"): + with pyro.plate("cov_place", data.n_sample_covariates): + mu_cov_loc = pyro.param( + "mu_cov_loc", torch.zeros((data.n_sample_covariates,)) + ) + mu_cov_scale = pyro.param( + "mu_cov_scale", + torch.ones((data.n_sample_covariates,)), + constraint=constraints.positive, + ) + mu_cov = pyro.sample("mu_cov", dist.Normal(mu_cov_loc, mu_cov_scale)) + assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape + + +def MixtureNormalGuide( + data, + alpha_prior: float = 1, + use_bcmatch: bool = True, + scale_by_accessibility: bool = False, + fit_noise: bool = False, +): + """ + Guide for MixtureNormal model + """ + + replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3) + guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) + + # Set the prior for phenotype means + mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1))) + mu_scale = pyro.param( + "mu_scale", torch.ones((data.n_targets, 1)), constraint=constraints.positive + ) + sd_loc = pyro.param("sd_loc", torch.zeros((data.n_targets, 1))) + sd_scale = pyro.param( + "sd_scale", torch.ones((data.n_targets, 1)), constraint=constraints.positive + ) + with pyro.plate("guide_plate0", 1): + with pyro.plate("guide_plate1", data.n_targets): + mu_targets = pyro.sample("mu_targets", dist.Normal(mu_loc, mu_scale)) + sd_targets = pyro.sample("sd_targets", dist.LogNormal(sd_loc, sd_scale)) + mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1) + mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) + assert mu.shape == (data.n_guides, 2) + + sd = torch.cat([torch.ones((data.n_targets, 1)), sd_targets], axis=-1) + sd = torch.repeat_interleave(sd, data.target_lengths, dim=0) + assert sd.shape == (data.n_guides, 2) + # The pi should be Dirichlet distributed instead of independent betas + alpha_pi = pyro.param( + "alpha_pi", + torch.ones( + ( + data.n_guides, + 2, + ) + ) + * alpha_prior, + constraint=constraints.positive, + ) + assert alpha_pi.shape == ( + data.n_guides, + 2, + ), alpha_pi.shape + pi_a_scaled = alpha_pi / alpha_pi.sum(axis=-1)[:, None] * data.pi_a0[:, None] + + with replicate_plate: + with guide_plate: + pi = pyro.sample( + "pi", + dist.Dirichlet( + pi_a_scaled.unsqueeze(0) + .unsqueeze(0) + .expand(data.n_reps, 1, -1, -1) + .clamp(1e-5) + ), + ) + assert pi.shape == ( + data.n_reps, + 1, + data.n_guides, + 2, + ), pi.shape + if scale_by_accessibility: + # Endogenous target site editing rate may be different + pi = scale_pi_by_accessibility( + pi, data.guide_accessibility, fit_noise=fit_noise + ) + + +def ControlNormalGuide(data, mask_thres=10, use_bcmatch=True): + """ + Fit shared mean + """ + # Set the prior for phenotype means + mu_loc = pyro.param("mu_loc", torch.tensor(0.0)) + mu_scale = pyro.param( + "mu_scale", torch.tensor(1.0), constraint=constraints.positive + ) + sd_loc = pyro.param("sd_loc", torch.tensor(0.0)) + sd_scale = pyro.param( + "sd_scale", torch.tensor(1.0), constraint=constraints.positive + ) + pyro.sample("mu_targets", dist.Normal(mu_loc, mu_scale)) + pyro.sample("sd_targets", dist.LogNormal(sd_loc, sd_scale)) + + def MultiMixtureNormalGuide( data, alpha_prior=1, @@ -819,23 +899,23 @@ def MultiMixtureNormalGuide( "sd_scale", torch.ones((data.n_edits,)), constraint=constraints.positive ) with pyro.plate("guide_plate1", data.n_edits): - mu_edits = pyro.sample("mu_alleles", dist.Normal(mu_loc, mu_scale)) + mu_edits = pyro.sample("mu_targets", dist.Normal(mu_loc, mu_scale)) sd_edits = pyro.sample( - "sd_alleles", + 
"sd_targets", dist.LogNormal(sd_loc, sd_scale), ) - mu_alleles = torch.matmul(data.allele_to_edit, mu_edits) - assert mu_alleles.shape == (data.n_guides, data.n_max_alleles - 1), ( - mu_alleles.shape, + mu_targets = torch.matmul(data.allele_to_edit, mu_edits) + assert mu_targets.shape == (data.n_guides, data.n_max_alleles - 1), ( + mu_targets.shape, data.n_max_alleles, data.n_edits, ) - sd_alleles = torch.linalg.norm( + sd_targets = torch.linalg.norm( data.allele_to_edit * sd_edits[None, None, :], dim=-1 ) # Frobenius 2-norm - mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_alleles], axis=-1) - sd = torch.cat([torch.ones((data.n_guides, 1)), sd_alleles], axis=-1) + mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_targets], axis=-1) + sd = torch.cat([torch.ones((data.n_guides, 1)), sd_targets], axis=-1) assert mu.shape == sd.shape == (data.n_guides, data.n_max_alleles), ( mu.shape, sd.shape, diff --git a/bean/model/parser.py b/bean/model/parser.py index 5786f1d..105036a 100755 --- a/bean/model/parser.py +++ b/bean/model/parser.py @@ -23,206 +23,229 @@ def parse_args(parser=None): help="Library design type whether to run variant or tiling screen model.\nVariant library design assumes gRNA has specific target variant and bystander edits are ignored. Tiling library design considers all alleles generated by gRNA in reporter.", ) parser.add_argument("bdata_path", type=str, help="Path of an ReporterScreen object") - parser.add_argument( - "--rep-pi", - "-r", - action="store_true", - default=False, - help="Fit replicate specific scaling factor. Recommended to set as True if you expect variable editing activity across biological replicates.", - ) - parser.add_argument( + run_parser = parser.add_argument_group("General run options") + run_parser.add_argument( "--uniform-edit", "-p", action="store_true", default=False, help="Assume uniform editing rate for all guides.", ) - parser.add_argument( + run_parser.add_argument( "--scale-by-acc", action="store_true", default=False, help="Scale guide editing efficiency by the target loci accessibility", ) - parser.add_argument( + run_parser.add_argument( "--acc-bw-path", type=str, default=None, help="Accessibility .bigWig file to be used to assign accessibility of guides.", ) - parser.add_argument( + run_parser.add_argument( "--acc-col", type=str, default=None, help="Column name in bdata.guides that specify raw ATAC-seq signal.", ) - parser.add_argument( - "--const-pi", + + run_parser.add_argument( + "--outdir", + "-o", + default=".", + type=str, + help="Directory to save the run result.", + ) + run_parser.add_argument( + "--result-suffix", + default="", + type=str, + help="Suffix of the output files", + ) + + run_parser.add_argument( + "--cuda", action="store_true", default=False, help="run on GPU" + ) + run_parser.add_argument( + "--fit-negctrl", + action="store_true", default=False, + help="Fit the shared negative control distribution to normalize the fitted parameters", + ) + run_parser.add_argument( + "--dont-fit-noise", # TODO: add check args action="store_true", - help="Use constant pi provided in --guide-activity-col (instead of fitting from reporter data)", ) - parser.add_argument( - "--shrink-alpha", - default=False, + run_parser.add_argument( + "--dont-adjust-confidence-by-negative-control", action="store_true", - help="Instead of using the trend-fitted alpha values, use estimated alpha values for each gRNA that are shrunk towards the fitted trend.", + help="Do not adjust confidence by negative controls. 
Without this flag, variant mode will use negative control variants, and tiling mode will use the synonymous variants to adjust confidence of the result.", ) - parser.add_argument( + run_parser.add_argument( + "--load-existing", # TODO: add check args + action="store_true", + help="Load existing .pkl file if present.", + ) + run_parser.add_argument( + "--save-raw", # TODO: add check args + action="store_true", + help="Write .pkl file with raw input/output.", + ) + run_parser.add_argument( + "--device", + type=str, + default=None, + help="Optionally use GPU if provided valid GPU device name (ex. cuda:0)", + ) + input_parser = parser.add_argument_group("Input .h5ad formatting") + input_parser.add_argument( "--condition-col", default="condition", type=str, help="Column key in `bdata.samples` that describes experimental condition.", ) - parser.add_argument( + input_parser.add_argument( "--time-col", default="time", type=str, help="Column key in `bdata.samples` that describes time elapsed.", ) - parser.add_argument( + input_parser.add_argument( "--control-condition", default="bulk", type=str, help="Value in `bdata.samples[condition_col]` that indicates control experimental condition whose editing patterns will be used. Select this as the condition with the least selection- For the sorting screen, use presort (bulk). For the survival screens, use the closest one with T=0.", ) - parser.add_argument( - "--exclude-control-condition-for-inference", - "-ec", - default=False, - action="store_true", - help="Exclude control conditions for inference. Currently only supported for survival screens.", - ) - parser.add_argument( + + input_parser.add_argument( "--replicate-col", default="replicate", type=str, help="Column key in `bdata.samples` that describes experimental replicates.", ) - parser.add_argument( + input_parser.add_argument( "--target-col", default="target", type=str, help="Column key in `bdata.guides` that describes the target element of each guide.", ) - parser.add_argument( + input_parser.add_argument( "--guide-activity-col", "-a", type=str, default=None, help="Column in ReporterScreen.guides DataFrame showing the editing rate estimated via external tools", ) - parser.add_argument( - "--outdir", - "-o", - default=".", - type=str, - help="Directory to save the run result.", - ) - parser.add_argument( - "--result-suffix", - default="", - type=str, - help="Suffix of the output files", - ) - parser.add_argument( + + input_parser.add_argument( "--sorting-bin-upper-quantile-col", "-uq", help="Column name with upper quantile values of each sorting bin in [Reporter]Screen.samples (or AnnData.var)", default="upper_quantile", ) - parser.add_argument( + input_parser.add_argument( "--sorting-bin-lower-quantile-col", "-lq", help="Column name with lower quantile values of each sorting bin in [Reporter]Screen.samples (or AnnData var)", default="lower_quantile", ) - parser.add_argument( - "--alpha-if-overdispersion-fitting-fails", - "-af", - default=None, - type=str, - help="Comma-separated regression coefficient (b0, b1) of log(a0) ~ log(q) that will be used if fitting dispersion on the data fails.", - ) - parser.add_argument("--cuda", action="store_true", default=False, help="run on GPU") - parser.add_argument( + + input_parser.add_argument( "--sample-mask-col", type=str, default="mask", help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. 
This can be used to mask out low-quality samples.", ) - parser.add_argument( - "--fit-negctrl", - action="store_true", - default=False, - help="Fit the shared negative control distribution to normalize the fitted parameters", - ) - parser.add_argument( + + input_parser.add_argument( "--negctrl-col", type=str, default="target_group", help="Column in bdata.obs specifying if a guide is negative control. If the `bdata.guides[negctrl_col].lower() == negctrl_col_value`, it is treated as negative control guide.", ) - parser.add_argument( + input_parser.add_argument( "--negctrl-col-value", type=str, default="negctrl", help="Column value in bdata.guides specifying if a guide is negative control. If the `bdata.guides[negctrl_col].lower() == negctrl_col_value`, it is treated as negative control guide.", ) - parser.add_argument( + input_parser.add_argument( "--repguide-mask", type=none_or_str, default="repguide_mask", help="n_replicate x n_guide mask to mask the outlier guides. screen.uns[repguide_mask] will be used.", ) - parser.add_argument( - "--device", - type=str, - default=None, - help="Optionally use GPU if provided valid GPU device name (ex. cuda:0)", - ) - parser.add_argument( - "--ignore-bcmatch", - action="store_true", - default=False, - help="If provided, even if the screen object has .X_bcmatch, ignore the count when fitting.", - ) - parser.add_argument( + + input_parser.add_argument( "--allele-df-key", type=str, default=None, help="screen.uns[allele_df_key] will be used as the allele count.", ) - parser.add_argument( + input_parser.add_argument( "--splice-site-path", type=str, default=None, help="Path to splicing site", ) - parser.add_argument( + input_parser.add_argument( "--control-guide-tag", type=none_or_str, default=None, help="If this string is in guide name, treat each guide separately not to mix the position. Used for negative controls.", ) - parser.add_argument( - "--dont-fit-noise", # TODO: add check args - action="store_true", - ) - parser.add_argument( - "--dont-adjust-confidence-by-negative-control", - action="store_true", - help="Do not adjust confidence by negative controls. Without this flag, variant mode will use negative control variants, and tiling mode will use the synonymous variants to adjust confidence of the result.", - ) - parser.add_argument( + + adv_parser = parser.add_argument_group("Advanced arguments for model fitting") + adv_parser.add_argument( "--n-iter", # TODO: add check args type=int, default=2000, help="# of SVI steps taken for inference.", ) - parser.add_argument( - "--load-existing", # TODO: add check args + adv_parser.add_argument( + "--ignore-bcmatch", action="store_true", - help="Load existing .pkl file if present.", + default=False, + help="If provided, even if the screen object has .X_bcmatch, ignore the count when fitting.", + ) + adv_parser.add_argument( + "--prior-params", + type=str, + default=None, + help="Path to the .pkl file with the dictionary containing prior parameters. Useful if your screens are batched into disjoint pool of gRNA libraries. Currently supports `mu_loc`, `mu_scale`, `sd_loc`, `sd_scale` for sorting screens and `mu_loc`, `mu_scale`, `initial_abundance` for survival screens.", + ) + adv_parser.add_argument( + "--rep-pi", + "-r", + action="store_true", + default=False, + help="Fit replicate specific scaling factor. 
Recommended to set as True if you expect variable editing activity across biological replicates.",
+    )
+    adv_parser.add_argument(
+        "--const-pi",
+        default=False,
+        action="store_true",
+        help="Use constant pi provided in --guide-activity-col (instead of fitting from reporter data)",
+    )
+    adv_parser.add_argument(
+        "--shrink-alpha",
+        default=False,
+        action="store_true",
+        help="Instead of using the trend-fitted alpha values, use estimated alpha values for each gRNA that are shrunk towards the fitted trend.",
+    )
+    adv_parser.add_argument(
+        "--exclude-control-condition-for-inference",
+        "-ec",
+        default=False,
+        action="store_true",
+        help="Exclude control conditions for inference. Currently only supported for survival screens.",
+    )
+    adv_parser.add_argument(
+        "--alpha-if-overdispersion-fitting-fails",
+        "-af",
+        default=None,
+        type=str,
+        help="Comma-separated regression coefficient (b0, b1) of log(a0) ~ log(q) that will be used if fitting dispersion on the data fails.",
     )
-
     return parser
diff --git a/bean/model/parser_prior.py b/bean/model/parser_prior.py
new file mode 100644
index 0000000..5212ffb
--- /dev/null
+++ b/bean/model/parser_prior.py
@@ -0,0 +1,29 @@
+import argparse
+
+
+def parse_args(parser=None):
+    if parser is None:
+        parser = argparse.ArgumentParser(
+            description="Generate prior_params.pkl for two batched runs, where the two runs share no guides but guides targeting the same variants are present in both libraries."
+        )
+    parser.add_argument(
+        "command1",
+        type=str,
+        help="bean run command for the first batched run.",
+    )
+    parser.add_argument(
+        "command2",
+        type=str,
+        help="bean run command for the second batched run.",
+    )
+    parser.add_argument(
+        "raw_run_output1",
+        type=str,
+        help="bean run output .pkl path for the first batched run, which should be run with --save-raw",
+    )
+    parser.add_argument(
+        "output_path",
+        type=str,
+        help="Output path to save prior parameters.",
+    )
+    return parser
diff --git a/bean/model/run.py b/bean/model/run.py
index 62f8bbf..aafac8b 100755
--- a/bean/model/run.py
+++ b/bean/model/run.py
@@ -15,6 +15,7 @@ def tqdm(iterable, **kwargs):
 import logging
 from functools import partial
 import pyro
+from bean.preprocessing.data_class import ScreenData
 import bean.model.model as sorting_model
 import bean.model.survival_model as survival_model
 
@@ -164,6 +165,7 @@ def check_args(args, bdata):
             )
     else:
         args.popt = None
+
     return args, bdata
 
 
@@ -208,6 +210,22 @@ def _get_guide_target_info(bdata, args, cols_include=[]):
 def run_inference(
     model, guide, data, initial_lr=0.01, gamma=0.1, num_steps=2000, autoguide=False
 ):
+    """
+    Run the inference process using stochastic variational inference (SVI) for the given model and guide.
+
+    Args:
+        model: The Pyro model to be used for inference.
+        guide: The Pyro guide to be used for inference.
+        data: The ScreenData object to be used in the inference process.
+        initial_lr: The initial learning rate for optimization (default is 0.01).
+        gamma: The factor by which the learning rate is decayed at each step (default is 0.1).
+        num_steps: The number of steps for the inference process (default is 2000).
+        autoguide: A flag indicating whether autoguide is used (default is False).
+
+    Returns:
+        Tuple containing the Pyro parameter store and a dictionary with loss information.
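+        The returned "params" dictionary maps each Pyro parameter name to a detached CPU tensor (via `pyro.param(k).data.cpu()`), so it can be pickled with --save-raw without holding GPU references.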
+    """
+
     pyro.clear_param_store()
     lrd = gamma ** (1 / num_steps)
     svi = pyro.infer.SVI(
@@ -235,7 +253,9 @@
     )
     return pyro.get_param_store(), {
         "loss": losses,
-        "params": pyro.get_param_store().get_state(),
+        "params": {
+            k: pyro.param(k).data.cpu() for k, v in pyro.get_param_store().items()
+        },
     }
 
 
@@ -315,3 +335,71 @@ def identify_negctrl_model_guide(args, data_has_bcmatch):
         use_bcmatch=(not args.ignore_bcmatch and data_has_bcmatch),
     )
     return negctrl_model, negctrl_guide
+
+
+from bean.preprocessing.data_class import SortingScreenData
+
+
+def _check_prior_params(param_path: str, ndata: ScreenData):
+    if os.path.exists(param_path):
+        with open(param_path, "rb") as f:
+            prior_params = pkl.load(f)
+    else:
+        raise ValueError(
+            f"Specified prior parameter file --prior-params {param_path} is not found."
+        )
+    if isinstance(ndata, SortingScreenData):
+        if "sd_loc" in prior_params:
+            if prior_params["sd_loc"].shape == (ndata.n_targets,):
+                prior_params["sd_loc"] = prior_params["sd_loc"].reshape(-1, 1)
+            elif prior_params["sd_loc"].shape != (ndata.n_targets, 1):
+                raise ValueError(
+                    f"Specified prior parameter --prior-params {param_path}: prior_params['sd_loc'].shape {prior_params['sd_loc'].shape} does not match the number of target variants {(ndata.n_targets, 1)}."
+                )
+        if "sd_scale" in prior_params:
+            if prior_params["sd_scale"].shape == (ndata.n_targets,):
+                prior_params["sd_scale"] = prior_params["sd_scale"].reshape(-1, 1)
+            elif prior_params["sd_scale"].shape != (ndata.n_targets, 1):
+                raise ValueError(
+                    f"Specified prior parameter --prior-params {param_path}: prior_params['sd_scale'].shape {prior_params['sd_scale'].shape} does not match the number of target variants {(ndata.n_targets, 1)}."
+                )
+        if "mu_loc" in prior_params:
+            if hasattr(prior_params["mu_loc"], "__len__"):
+                if prior_params["mu_loc"].shape == (ndata.n_targets,):
+                    prior_params["mu_loc"] = prior_params["mu_loc"].reshape(-1, 1)
+                elif prior_params["mu_loc"].shape != (ndata.n_targets, 1):
+                    raise ValueError(
+                        f"Specified prior parameter --prior-params {param_path}: prior_params['mu_loc'].shape {prior_params['mu_loc'].shape} does not match the number of target variants {(ndata.n_targets, 1)}."
+                    )
+        if "mu_scale" in prior_params:
+            if hasattr(prior_params["mu_scale"], "__len__"):
+                if prior_params["mu_scale"].shape == (ndata.n_targets,):
+                    prior_params["mu_scale"] = prior_params["mu_scale"].reshape(-1, 1)
+                elif prior_params["mu_scale"].shape != (ndata.n_targets, 1):
+                    raise ValueError(
+                        f"Specified prior parameter --prior-params {param_path}: prior_params['mu_scale'].shape {prior_params['mu_scale'].shape} does not match the number of target variants {(ndata.n_targets, 1)}."
+                    )
+    else:
+        # Survival model
+        if "initial_abundance" in prior_params:
+            if prior_params["initial_abundance"].shape != (ndata.n_guides,):
+                raise ValueError(
+                    f"Specified prior parameter --prior-params {param_path}: prior_params['initial_abundance'].shape does not match the number of guides {(ndata.n_guides,)}."
+                )
+        if "mu_loc" in prior_params:
+            if hasattr(prior_params["mu_loc"], "__len__"):
+                if prior_params["mu_loc"].shape == (ndata.n_targets,):
+                    prior_params["mu_loc"] = prior_params["mu_loc"].reshape(-1, 1)
+                elif prior_params["mu_loc"].shape != (ndata.n_targets, 1):
+                    raise ValueError(
+                        f"Specified prior parameter --prior-params {param_path}: prior_params['mu_loc'].shape does not match the number of target variants {(ndata.n_targets, 1)}."
+ ) + if "mu_scale" in prior_params: + if hasattr(prior_params["mu_scale"], "__len__"): + if prior_params["mu_scale"].shape == (ndata.n_targets): + prior_params["mu_scale"] = prior_params["mu_scale"].reshape(-1, 1) + elif prior_params["mu_scale"].shape != (ndata.n_targets, 1): + raise ValueError( + f"Specified prior parameter --prior-params {param_path}: prior_params['mu_scale'].shape does not match the number of target variants {(ndata.n_targets, 1)}." + ) + return prior_params diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 6a33b18..9fdfbc6 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -1,3 +1,4 @@ +from typing import Optional import torch import pyro from pyro import poutine @@ -12,26 +13,47 @@ def NormalModel( - data: VariantSurvivalScreenData, mask_thres: int = 10, use_bcmatch: bool = True + data: VariantSurvivalScreenData, + mask_thres: int = 10, + use_bcmatch: bool = True, + prior_params: Optional[dict] = None, ): """ Fit only on guide counts - Args - -- - data: input data + + Args: + data: input data + mask_thres: threshold for masking guide counts for stability. Defaults to 10. + use_bcmatch: whether to use barcode-matched counts. Defaults to True. + sd_scale: scale for prior standard deviation. Defaults to 0.01 for improved identifiability. + prior_params: prior parameters. If provided, specified prior parameters will be used. """ replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3) replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2) time_plate = pyro.plate("time_plate", data.n_condits, dim=-2) guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) + mu_dist = dist.Laplace(0, 1) + initial_abundance = torch.ones(data.n_guides) / data.n_guides + if prior_params is not None: + if "mu_loc" in prior_params or "mu_scale" in prior_params: + mu_loc = 0.0 + mu_scale = 1.0 + if "mu_loc" in prior_params: + mu_loc = prior_params["mu_loc"] + if "mu_scale" in prior_params: + mu_scale = prior_params["mu_scale"] + mu_dist = dist.Normal(mu_loc, mu_scale) + if "initial_abundance" in prior_params: + initial_abundance = prior_params["initial_abundance"] + # Set the prior for phenotype means with pyro.plate("guide_plate0", 1): with pyro.plate("target_plate", data.n_targets): # In survival analysis, fitted effect size is not - mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1)) + mu_targets = pyro.sample("mu_targets", mu_dist) - mu_center = mu_alleles + mu_center = mu_targets mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) r = torch.exp(mu) assert r.shape == (data.n_guides, 1) @@ -39,13 +61,12 @@ def NormalModel( with pyro.plate("replicate_plate0", data.n_reps, dim=-1): q_0 = pyro.sample( "initial_guide_abundance", - dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))), + dist.Dirichlet(initial_abundance.unsqueeze(0).expand(data.n_reps, -1)), ) with replicate_plate: with time_plate as t: time = data.timepoints[t] assert time.shape == (data.n_condits,) - # with guide_plate, poutine.mask(mask=(data.allele_counts.sum(axis=-1) == 0)): with guide_plate: alleles_p_time = torch.pow( r.unsqueeze(0).expand((data.n_condits, -1, -1)), @@ -117,8 +138,8 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True): guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) # Set the prior for phenotype means - mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1)) - mu = mu_alleles.repeat(data.n_guides).unsqueeze(-1) + mu_targets = pyro.sample("mu_targets", dist.Laplace(0, 1)) + 
+    mu = mu_targets.repeat(data.n_guides).unsqueeze(-1)
     r = torch.exp(mu)
     with pyro.plate("rep_plate1", data.n_reps, dim=-1):
         q_0 = pyro.sample(
@@ -201,10 +222,19 @@ def MixtureNormalModel(
     sd_scale: float = 0.01,
     scale_by_accessibility: bool = False,
     fit_noise: bool = False,
+    prior_params: Optional[dict] = None,
 ):
     """
+    Using the reporter outcome, phenotype of cells with a guide will be modeled as mixture of two normal distributions of edited and unedited cells.
+
+    Args:
-        scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility
+        data: Input survival reporter screen data.
+        alpha_prior: Prior parameter for controlling the concentration of the Dirichlet process. Defaults to 1.
+        use_bcmatch: Flag indicating whether to use barcode-matched counts. Defaults to True.
+        sd_scale: Scale for the prior standard deviation. Defaults to 0.01.
+        scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility.
+        fit_noise: Valid only when scale_by_accessibility is True. If True, parametrically fit noise of endo ~ reporter + noise.
+        prior_params: Optional dictionary of prior parameters. If provided, specified prior parameters will be used.
     """
     torch.autograd.set_detect_anomaly(True)
     replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3)
     replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2)
     time_plate = pyro.plate("time_plate", data.n_condits, dim=-2)
     guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1)
 
+    mu_dist = dist.Laplace(0, 1)
+    initial_abundance = torch.ones(data.n_guides) / data.n_guides
+    if prior_params is not None:
+        if "mu_loc" in prior_params or "mu_scale" in prior_params:
+            mu_loc = 0.0
+            mu_scale = 1.0
+            if "mu_loc" in prior_params:
+                mu_loc = prior_params["mu_loc"]
+            if "mu_scale" in prior_params:
+                mu_scale = prior_params["mu_scale"]
+            mu_dist = dist.Normal(mu_loc, mu_scale)
+        if "initial_abundance" in prior_params:
+            initial_abundance = prior_params["initial_abundance"]
+
     # Set the prior for phenotype means
     with pyro.plate("guide_plate0", 1):
         with pyro.plate("guide_plate1", data.n_targets):
-            mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1))
-    mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_alleles], axis=-1)
+            mu_targets = pyro.sample("mu_targets", mu_dist)
+    mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1)
     mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0)
     assert mu.shape == (data.n_guides, 2)
     r = torch.exp(mu)
@@ -224,7 +268,7 @@
     with pyro.plate("replicate_plate0", data.n_reps, dim=-1):
         q_0 = pyro.sample(
             "initial_guide_abundance",
-            dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))),
+            dist.Dirichlet(initial_abundance.unsqueeze(0).expand(data.n_reps, -1)),
         )
     alpha_pi = pyro.param(
         "alpha_pi",
@@ -341,117 +385,6 @@
     )
 
 
-def NormalGuide(data):
-    initial_abundance = pyro.param(
-        "initial_abundance",
-        torch.ones(data.n_guides) / data.n_guides,
-        constraint=constraints.positive,
-    )
-    with pyro.plate("replicate_plate0", data.n_reps, dim=-1):
-        q_0 = pyro.sample(
-            "initial_guide_abundance",
-            dist.Dirichlet(initial_abundance),
-        )
-    with pyro.plate("guide_plate0", 1):
-        with pyro.plate("guide_plate1", data.n_targets):
-            mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1)))
-            mu_scale = pyro.param(
-                "mu_scale",
-                torch.ones((data.n_targets, 1)),
-                constraint=constraints.positive,
-            )
-            pyro.sample("mu_alleles", dist.Normal(mu_loc, mu_scale))
-
-
-def MixtureNormalGuide(
-    data,
-    alpha_prior: float = 1,
-    use_bcmatch: bool = True,
-    scale_by_accessibility: bool = False,
-    fit_noise: bool = False,
-):
-    """
-    Guide for MixtureNormal model
-    """
-
-    replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3)
-    guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1)
-    initial_abundance = pyro.param(
-        "initial_abundance",
-        torch.ones(data.n_guides) / data.n_guides,
-        constraint=constraints.positive,
-    )
-    with pyro.plate("replicate_plate0", data.n_reps, dim=-1):
-        q_0 = pyro.sample(
-            "initial_guide_abundance",
-            dist.Dirichlet(initial_abundance),
-        )
-    # Set the prior for phenotype means
-    mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1)))
-    mu_scale = pyro.param(
-        "mu_scale", torch.ones((data.n_targets, 1)), constraint=constraints.positive
-    )
-    with pyro.plate("guide_plate0", 1):
-        with pyro.plate("guide_plate1", data.n_targets):
-            mu_alleles = pyro.sample("mu_alleles", dist.Normal(mu_loc, mu_scale))
-    mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_alleles], axis=-1)
-    mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0)
-    assert mu.shape == (data.n_guides, 2)
-
-    # The pi should be Dirichlet distributed instead of independent betas
-    alpha_pi = pyro.param(
-        "alpha_pi",
-        torch.ones(
-            (
-                data.n_guides,
-                2,
-            )
-        )
-        * alpha_prior,
-        constraint=constraints.positive,
-    )
-    assert alpha_pi.shape == (
-        data.n_guides,
-        2,
-    ), alpha_pi.shape
-    pi_a_scaled = alpha_pi / alpha_pi.sum(axis=-1)[:, None] * data.pi_a0[:, None]
-
-    with replicate_plate:
-        with guide_plate:
-            pi = pyro.sample(
-                "pi",
-                dist.Dirichlet(
-                    pi_a_scaled.unsqueeze(0)
-                    .unsqueeze(0)
-                    .expand(data.n_reps, 1, -1, -1)
-                    .clamp(1e-5)
-                ),
-            )
-            assert pi.shape == (
-                data.n_reps,
-                1,
-                data.n_guides,
-                2,
-            ), pi.shape
-    if scale_by_accessibility:
-        # Endogenous target site editing rate may be different
-        pi = scale_pi_by_accessibility(
-            pi, data.guide_accessibility, fit_noise=fit_noise
-        )
-
-
-def ControlNormalGuide(data, mask_thres=10, use_bcmatch=True):
-    """
-    Fit shared mean
-    """
-    # Set the prior for phenotype means
-    mu_loc = pyro.param("mu_loc", torch.tensor(0.0))
-    mu_scale = pyro.param(
-        "mu_scale", torch.tensor(1.0), constraint=constraints.positive
-    )
-    pyro.sample("mu_alleles", dist.Normal(mu_loc, mu_scale))
-
-
 def MultiMixtureNormalModel(
     data: TilingSurvivalReporterScreenData,
     alpha_prior=1,
@@ -459,34 +392,62 @@
     sd_scale=0.01,
     norm_pi=False,
     scale_by_accessibility=False,
-    epsilon=1e-5,
     fit_noise: bool = False,
+    prior_params: Optional[dict] = None,
+    epsilon=1e-5,
 ):
-    """Tiling version of MixtureNormalModel"""
+    """
+    Using the reporter outcome, phenotype of cells with a guide will be modeled as mixture of normal distributions of all major alleles (including WT) produced by the guide.
+
+    Args:
+        data: Input data of type TilingSurvivalReporterScreenData.
+        alpha_prior: Prior parameter for controlling the concentration of the Dirichlet process. Defaults to 1.
+        use_bcmatch: Flag indicating whether to use barcode-matched counts. Defaults to True.
+        sd_scale: Scale for the prior standard deviation. Defaults to 0.01.
+        scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility.
+        fit_noise: Valid only when scale_by_accessibility is True. If True, parametrically fit noise of endo ~ reporter + noise.
+        prior_params: Optional dictionary of prior parameters. If provided, specified prior parameters will be used.
+ epsilon: Small value to avoid division by zero, assigned as Dirichlet parameters for non-existing alleles. + """ replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3) replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2) time_plate = pyro.plate("time_plate", data.n_condits, dim=-2) guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) + + mu_dist = dist.Laplace(0, 1) + initial_abundance = torch.ones(data.n_guides) / data.n_guides + if prior_params is not None: + if "mu_loc" in prior_params or "mu_scale" in prior_params: + mu_loc = 0.0 + mu_scale = 1.0 + if "mu_loc" in prior_params: + mu_loc = prior_params["mu_loc"] + if "mu_scale" in prior_params: + mu_scale = prior_params["mu_scale"] + mu_dist = dist.Normal(mu_loc, mu_scale) + if "initial_abundance" in prior_params: + initial_abundance = prior_params["initial_abundance"] + # Set the prior for phenotype means with pyro.plate("guide_plate1", data.n_edits): - mu_edits = pyro.sample("mu_alleles", dist.Laplace(0, 1)) + mu_edits = pyro.sample("mu_targets", mu_dist) assert mu_edits.shape == (data.n_edits,) assert data.allele_to_edit.shape == ( data.n_guides, data.n_max_alleles - 1, data.n_edits, ) - mu_alleles = torch.matmul(data.allele_to_edit, mu_edits) - assert mu_alleles.shape == (data.n_guides, data.n_max_alleles - 1) + mu_targets = torch.matmul(data.allele_to_edit, mu_edits) + assert mu_targets.shape == (data.n_guides, data.n_max_alleles - 1) - mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_alleles], axis=-1) + mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_targets], axis=-1) r = torch.exp(mu) with pyro.plate("replicate_plate0", data.n_reps, dim=-1): q_0 = pyro.sample( "initial_guide_abundance", - dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))), + dist.Dirichlet(initial_abundance.unsqueeze(0).expand(data.n_reps, -1)), ) # The pi should be Dirichlet distributed instead of independent betas alpha_pi0 = ( @@ -630,6 +591,113 @@ def MultiMixtureNormalModel( raise e +def NormalGuide(data): + initial_abundance = pyro.param( + "initial_abundance", + torch.ones(data.n_guides) / data.n_guides, + constraint=constraints.positive, + ) + with pyro.plate("replicate_plate0", data.n_reps, dim=-1): + q_0 = pyro.sample( + "initial_guide_abundance", + dist.Dirichlet(initial_abundance), + ) + with pyro.plate("guide_plate0", 1): + with pyro.plate("guide_plate1", data.n_targets): + mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1))) + mu_scale = pyro.param( + "mu_scale", + torch.ones((data.n_targets, 1)), + constraint=constraints.positive, + ) + pyro.sample("mu_targets", dist.Normal(mu_loc, mu_scale)) + + +def MixtureNormalGuide( + data, + alpha_prior: float = 1, + use_bcmatch: bool = True, + scale_by_accessibility: bool = False, + fit_noise: bool = False, +): + replicate_plate = pyro.plate("rep_plate", data.n_reps, dim=-3) + guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) + initial_abundance = pyro.param( + "initial_abundance", + torch.ones(data.n_guides) / data.n_guides, + constraint=constraints.positive, + ) + with pyro.plate("replicate_plate0", data.n_reps, dim=-1): + q_0 = pyro.sample( + "initial_guide_abundance", + dist.Dirichlet(initial_abundance), + ) + # Set the prior for phenotype means + mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets, 1))) + mu_scale = pyro.param( + "mu_scale", torch.ones((data.n_targets, 1)), constraint=constraints.positive + ) + with pyro.plate("guide_plate0", 1): + with pyro.plate("guide_plate1", data.n_targets): + mu_targets = 
pyro.sample("mu_targets", dist.Normal(mu_loc, mu_scale)) + mu_center = torch.cat([torch.zeros((data.n_targets, 1)), mu_targets], axis=-1) + mu = torch.repeat_interleave(mu_center, data.target_lengths, dim=0) + assert mu.shape == (data.n_guides, 2) + + # The pi should be Dirichlet distributed instead of independent betas + alpha_pi = pyro.param( + "alpha_pi", + torch.ones( + ( + data.n_guides, + 2, + ) + ) + * alpha_prior, + constraint=constraints.positive, + ) + assert alpha_pi.shape == ( + data.n_guides, + 2, + ), alpha_pi.shape + pi_a_scaled = alpha_pi / alpha_pi.sum(axis=-1)[:, None] * data.pi_a0[:, None] + + with replicate_plate: + with guide_plate: + pi = pyro.sample( + "pi", + dist.Dirichlet( + pi_a_scaled.unsqueeze(0) + .unsqueeze(0) + .expand(data.n_reps, 1, -1, -1) + .clamp(1e-5) + ), + ) + assert pi.shape == ( + data.n_reps, + 1, + data.n_guides, + 2, + ), pi.shape + if scale_by_accessibility: + # Endogenous target site editing rate may be different + pi = scale_pi_by_accessibility( + pi, data.guide_accessibility, fit_noise=fit_noise + ) + + +def ControlNormalGuide(data, mask_thres=10, use_bcmatch=True): + """ + Fit shared mean + """ + # Set the prior for phenotype means + mu_loc = pyro.param("mu_loc", torch.tensor(0.0)) + mu_scale = pyro.param( + "mu_scale", torch.tensor(1.0), constraint=constraints.positive + ) + pyro.sample("mu_targets", dist.Normal(mu_loc, mu_scale)) + + def MultiMixtureNormalGuide( data, alpha_prior=1, @@ -660,15 +728,15 @@ def MultiMixtureNormalGuide( "mu_scale", torch.ones((data.n_edits,)), constraint=constraints.positive ) with pyro.plate("guide_plate1", data.n_edits): - mu_edits = pyro.sample("mu_alleles", dist.Normal(mu_loc, mu_scale)) - mu_alleles = torch.matmul(data.allele_to_edit, mu_edits) - assert mu_alleles.shape == (data.n_guides, data.n_max_alleles - 1), ( - mu_alleles.shape, + mu_edits = pyro.sample("mu_targets", dist.Normal(mu_loc, mu_scale)) + mu_targets = torch.matmul(data.allele_to_edit, mu_edits) + assert mu_targets.shape == (data.n_guides, data.n_max_alleles - 1), ( + mu_targets.shape, data.n_max_alleles, data.n_edits, ) - mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_alleles], axis=-1) + mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_targets], axis=-1) assert mu.shape == (data.n_guides, data.n_max_alleles), (mu.shape,) # The pi should be Dirichlet distributed instead of independent betas alpha_pi0 = ( diff --git a/bean/model/utils.py b/bean/model/utils.py index 35352ec..6363049 100755 --- a/bean/model/utils.py +++ b/bean/model/utils.py @@ -4,7 +4,8 @@ import pyro.distributions as dist import pyro.distributions.constraints as constraints -MAX_LOGPI=10 +MAX_LOGPI = 10 + def get_alpha( expected_guide_p, size_factor, sample_mask, a0, epsilon=1e-5, normalize_by_a0=True diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index fe38b6a..5e3da55 100755 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -258,10 +258,8 @@ class ReporterScreenData(ScreenData): size_factor_bcmatch: torch.Tensor X_bcmatch_control: torch.Tensor size_factor_bcmatch_control: torch.Tensor - allele_counts: torch.Tensor allele_counts_control: torch.Tensor a0_bcmatch: torch.Tensor - a0_allele: torch.Tensor pi_a0: torch.Tensor def __init__( diff --git a/docs/ReporterScreen_api_files/ReporterScreen_api_35_1.png b/docs/ReporterScreen_api_files/ReporterScreen_api_35_1.png new file mode 100644 index 0000000..b5b724d Binary files /dev/null and b/docs/ReporterScreen_api_files/ReporterScreen_api_35_1.png differ 
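For reference alongside the `prior_params` plumbing added above: the dictionary is pickled and passed to `bean run` via `--prior-params`. Below is a minimal sketch (not part of this diff) of assembling one by hand for the variant-mode model; only the keys the models above read — `mu_loc`, `mu_scale`, `initial_abundance` — are shown, the library dimensions are hypothetical, and the tiling model would instead expect `mu_loc`/`mu_scale` shaped against its per-edit sample site.

```python
# Minimal sketch, assuming hypothetical library dimensions: hand-built prior
# for the variant-mode model. Keys match the prior_params handling above.
import pickle as pkl

import torch

n_targets, n_guides = 100, 500  # hypothetical; must match your screen

prior_params = {
    # Replaces the default Laplace(0, 1) prior on "mu_targets" with
    # Normal(mu_loc, mu_scale); shapes broadcast against (n_targets, 1).
    "mu_loc": torch.zeros((n_targets, 1)),
    "mu_scale": 2.0 * torch.ones((n_targets, 1)),
    # Dirichlet concentration for initial guide abundance (default: uniform).
    "initial_abundance": torch.ones(n_guides) / n_guides,
}

with open("prior_params.pkl", "wb") as f:
    pkl.dump(prior_params, f)  # then: bean run ... --prior-params prior_params.pkl
```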
diff --git a/docs/_build_prior.md b/docs/_build_prior.md
new file mode 100644
index 0000000..2f7969b
--- /dev/null
+++ b/docs/_build_prior.md
@@ -0,0 +1,3 @@
+# Build custom prior for batch integration
+
+This helper command is used for batch integration when the library is split into two disjoint sublibraries. For usage, see the [batch integration tutorial](https://pinellolab.github.io/crispr-bean/tutorial_custom_prior.html).
\ No newline at end of file
diff --git a/docs/_model.md b/docs/_model.md
new file mode 100644
index 0000000..2fd9f95
--- /dev/null
+++ b/docs/_model.md
@@ -0,0 +1 @@
+TBD
\ No newline at end of file
diff --git a/docs/_run.md b/docs/_run.md
index acfebb9..53966d8 100755
--- a/docs/_run.md
+++ b/docs/_run.md
@@ -4,7 +4,7 @@ BEAN uses Bayesian network to incorporate gRNA editing outcome to provide poster
 2. The weight of the mixture components are inferred from the reporter editing outcome and the chromatin accessibility of the loci.
 3. Cells with each gRNA, formulated as the mixture distribution, is sorted by the phenotypic quantile to produce the gRNA counts.
 
-For the full detail, see the method section of the [BEAN manuscript](https://www.medrxiv.org/content/10.1101/2023.09.08.23295253v1).
+For full details on the model, see the [model description](https://pinellolab.github.io/crispr-bean/model.html).
 
 model
diff --git a/docs/_tutorial_custom_prior.md b/docs/_tutorial_custom_prior.md
new file mode 100644
index 0000000..843a868
--- /dev/null
+++ b/docs/_tutorial_custom_prior.md
@@ -0,0 +1,32 @@
+# Feeding custom prior into `bean run`
+
+In this tutorial, we consider the case where the user wants to feed per-variant prior beliefs into the model using `--prior-params=prior_params.pkl` for `bean run`.
+In particular, we walk through the case where the library is split into two sublibraries whose guides are disjoint, but where many guides in the two sublibraries target shared variants. For this case, we provide the helper command `bean build-prior`.
+
+## Example workflow
+```bash
+screen_id1=var_mini_screen_sub1
+screen_id2=var_mini_screen_sub2
+working_dir=tests/data/
+output_dir=tests/test_res/
+
+# 1. Run sublibrary 1
+# It is important to feed --save-raw so that the run output contains the input for the next step.
+bean run sorting variant $working_dir/${screen_id1}.h5ad --control-condition D14_1 -o $working_dir --fit-negctrl --save-raw
+
+# 2. Build prior
+# Feed the first and second `bean run` commands, the raw output of the first run, and the pickle file path that will store prior_params.
+# Usage: bean build-prior command_run1 command_run2 raw_output_run1 param_output
+bean build-prior \
+"bean run sorting variant $working_dir/${screen_id1}.h5ad --control-condition D14_1 -o $working_dir --fit-negctrl --save-raw" \
+"bean run sorting variant $working_dir/${screen_id2}.h5ad --control-condition D14_2 -o $working_dir --fit-negctrl" \
+$working_dir/bean_run_result.${screen_id1}/MixtureNormal.result.pkl \
+$working_dir/prior_params.pkl
+
+# 3. Run sublibrary 2 with the specified prior
+# Feed in the prior_params.pkl file from the previous step.
+bean run sorting variant $working_dir/${screen_id2}.h5ad --control-condition D14_2 -o $output_dir --fit-negctrl --prior-params $working_dir/prior_params.pkl
+```
+
+## Manually specifying prior
+TBD. If this functionality is desired, please open an issue for expedited documentation :)
\ No newline at end of file
diff --git a/docs/build_prior.rst b/docs/build_prior.rst
new file mode 100644
index 0000000..d7d9e4c
--- /dev/null
+++ b/docs/build_prior.rst
@@ -0,0 +1,10 @@
+`bean build-prior`
+***********************
+.. mdinclude:: _build_prior.md
+
+Full parameters
+==================
+.. argparse::
+   :filename: ../bean/model/parser_prior.py
+   :func: parse_args
+   :prog: bean build-prior
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index e5ba831..b385d81 100755
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -14,11 +14,7 @@
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
 
-extensions = [
-    "sphinxarg.ext",
-    "m2r",
-    "sphinx.ext.extlinks",
-]
+extensions = ["sphinxarg.ext", "m2r", "sphinx.ext.extlinks", "sphinx.ext.mathjax"]
 
 templates_path = ["_templates"]
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
diff --git a/docs/filter_alleles.ipynb b/docs/filter_alleles.ipynb
new file mode 100644
index 0000000..7fdeef1
--- /dev/null
+++ b/docs/filter_alleles.ipynb
@@ -0,0 +1,1103 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "import argparse\n",
+    "import logging\n",
+    "import bean as be\n",
+    "from bean.plotting.allele_stats import plot_n_alleles_per_guide, plot_n_guides_per_edit\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "plt.style.use(\"default\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bdata_path = \"/PHShome/jr1025//projects/ANBE/anbe_manuscript/workflow/results/filtered_annotated/LDLRCDS/bean_count_LDLRCDS_masked.h5ad\"\n",
+    "out_bdata_path = \"/PHShome/jr1025//projects/ANBE/anbe_manuscript/workflow/results/filtered_annotated/LDLRCDS/bean_count_LDLRCDS_alleleFiltered.h5ad\"\n",
+    "plasmid_path = \"/PHShome/jr1025//projects/ANBE/anbe_manuscript/workflow/results/mapped/LDLRCDS/bean_count_LDLRCDS_plasmid.h5ad\"\n",
+    "edit_start_pos = 2\n",
+    "edit_end_pos = 7\n",
+    "filter_allele_proportion = 0.05\n",
+    "filter_sample_proportion = 0.2\n",
+    "jaccard_threshold = 0.5\n",
+    "output_prefix = bdata_path.rsplit(\".h5ad\", 1)[0] + \"_alleleFiltered\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bdata = be.read_h5ad(bdata_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Starting from .uns['allele_counts'] with 987700 alleles.\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\n",
+    "    f\"Starting from .uns['allele_counts'] with {len(bdata.uns['allele_counts'])} alleles.\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "allele_df_keys = [\"allele_counts\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running Fisher's exact test to get significant edits compared to control...\n",
+      "Done calculating significance.\n",
+      "\n",
+      "\n",
+      "Filtering alleles for those containing significant edits (q < 0.05)...\n",
+      "Running 30 parallel processes to filter alleles...\n",
+      "Done filtering alleles, merging the 
result...\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": "{l_bar}{bar}{r_bar}", + "colour": null, + "elapsed": 0.016897201538085938, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "Mapping alleles to closest filtered alleles", + "rate": null, + "total": 7469, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "38d1de311abf44cd829ce4c3109c0a65", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Mapping alleles to closest filtered alleles: 0%| | 0/7469 [00:00G\n", + "ref:C at pos 1415, got edit 11224268:9:+:A>G\n", + "ref:C at pos 1415, got edit 11224268:9:+:A>G\n", + "Cannot translate codon due to ambiguity: GG-\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "ref:T at pos 1958, got edit 11230881:11:+:C>G\n", + "ref:C at pos 1973, got edit 11230896:10:-:A>G\n", + "ref:C at pos 1973, got edit 11230896:11:-:A>G\n", + "Cannot translate codon due to ambiguity: T-C\n", + "ref:A at pos 2231, got edit 11233941:10:-:C>T\n", + "ref:A at pos 2231, got edit 11233941:10:-:C>T\n", + "ref:C at pos 2248, got edit 11233958:9:-:A>G\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "reached the end of CDS, frameshift.\n", + "Filtered down to 10278 alleles.\n" + ] + } + ], + "source": [ + "print(\n", + " \"Translating alleles...\"\n", + ") # TODO: Check & document custom fasta file for translation\n", + "filtered_key = f\"{allele_df_keys[-1]}_translated\"\n", + "bdata.uns[filtered_key] = be.translate_allele_df(\n", + " bdata.uns[allele_df_keys[-1]], \n", + ").rename(columns={\"allele\": \"aa_allele\"})\n", + "allele_df_keys.append(filtered_key)\n", + "print(f\"Filtered down to {len(bdata.uns[filtered_key])} alleles.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Filter by allele proportion" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Filtering alleles for those have allele fraction 0.05 in at 
least 20.0% of samples...\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": "{l_bar}{bar}{r_bar}", + "colour": null, + "elapsed": 0.01263427734375, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "", + "rate": null, + "total": 5169, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/5169 [00:00" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(len(allele_df_keys), 2, figsize=(6, 3 * len(allele_df_keys)))\n", + "for i, key in enumerate(allele_df_keys):\n", + " plot_n_alleles_per_guide(bdata, key, bdata.uns[key].columns[1], ax[i, 0])\n", + " plot_n_guides_per_edit(bdata, key, bdata.uns[key].columns[1], ax[i, 1])\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{ output_prefix}.filtered_allele_stats.pdf\", bbox_inches=\"tight\")\n", + "print(\n", + " f\"Saving plotting result and log at { output_prefix}.[filtered_allele_stats.pdf, filter_log.txt].\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Write log" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "with open(f\"{output_prefix}.filter_log.txt\", \"w\") as out_log:\n", + " for key in allele_df_keys:\n", + " out_log.write(f\"{key}\\t{len(bdata.uns[key])}\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save result" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "bdata.write(out_bdata_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[HTML table render of bdata.guides omitted (7500 rows × 16 columns); see the text/plain repr below.]\n",
+       "
" + ], + "text/plain": [ + " Region pos strand sequence \\\n", + "name \n", + "5' UTR+pro_28_neg 5' UTR+pro 28 neg TCCTTCGAAAGTGTCGCCAG \n", + "5' UTR+pro_28_pos 5' UTR+pro 28 pos CTTTCGAAGGACTGGAGTGG \n", + "5' UTR+pro_31_neg 5' UTR+pro 31 neg CAGTCCTTCGAAAGTGTCGC \n", + "5' UTR+pro_31_pos 5' UTR+pro 31 pos TCGAAGGACTGGAGTGGGAA \n", + "5' UTR+pro_34_neg 5' UTR+pro 34 neg CTCCAGTCCTTCGAAAGTGT \n", + "... ... .. ... ... \n", + "CBE_CONTROL_40_pos CBE control 40 NaN TATCGCGCTTGGGTTATACG \n", + "CBE_CONTROL_96_pos CBE control 96 NaN ATTAGCCGTTGCCATATCAA \n", + "CBE_CONTROL_97_pos CBE control 97 NaN GTCCCTCAGGGTGCAACTT \n", + "CBE_CONTROL_98_pos CBE control 98 NaN TCCTCATCCGGTCAGGCTG \n", + "CBE_CONTROL_99_pos CBE control 99 NaN TAACGCGCATATCTGAACAC \n", + "\n", + " Reporter barcode 5-nt PAM \\\n", + "name \n", + "5' UTR+pro_28_neg CTCCAGTCCTTCGAAAGTGTCGCCAGGGCAGG CCTC GGCAG \n", + "5' UTR+pro_28_pos GCGACACTTTCGAAGGACTGGAGTGGGAATCA CAAA GAATC \n", + "5' UTR+pro_31_neg CCACTCCAGTCCTTCGAAAGTGTCGCCAGGGC CCTA CAGGG \n", + "5' UTR+pro_31_pos ACACTTTCGAAGGACTGGAGTGGGAATCAGAG TGCC TCAGA \n", + "5' UTR+pro_34_neg TTCCCACTCCAGTCCTTCGAAAGTGTCGCCAG CTGA CGCCA \n", + "... ... ... ... \n", + "CBE_CONTROL_40_pos GAAAAATATCGCGCTTGGGTTATACGCTCCAA TACC CTCCA \n", + "CBE_CONTROL_96_pos TCCGTTATTAGCCGTTGCCATATCAAATGAGA GCTC ATGAG \n", + "CBE_CONTROL_97_pos TCTAAGGGTCCCTCAGGGTGCAACTTTGGTAC AGAA TGGTA \n", + "CBE_CONTROL_98_pos GTGAAAGTCCTCATCCGGTCAGGCTGGTGTTA GAGC GTGTT \n", + "CBE_CONTROL_99_pos CTAAGGTAACGCGCATATCTGAACACTGCAAT AAGA TGCAA \n", + "\n", + " targetPos pos_seq guide_len start_pos \\\n", + "name \n", + "5' UTR+pro_28_neg 11199966.0 CTGGCGACACTTTCGAAGGA 20 11199952 \n", + "5' UTR+pro_28_pos 11199966.0 CTTTCGAAGGACTGGAGTGG 20 11199961 \n", + "5' UTR+pro_31_neg 11199969.0 GCGACACTTTCGAAGGACTG 20 11199955 \n", + "5' UTR+pro_31_pos 11199969.0 TCGAAGGACTGGAGTGGGAA 20 11199964 \n", + "5' UTR+pro_34_neg 11199972.0 ACACTTTCGAAGGACTGGAG 20 11199958 \n", + "... ... ... ... ... \n", + "CBE_CONTROL_40_pos NaN TATCGCGCTTGGGTTATACG 20 -1 \n", + "CBE_CONTROL_96_pos NaN ATTAGCCGTTGCCATATCAA 20 -1 \n", + "CBE_CONTROL_97_pos NaN GTCCCTCAGGGTGCAACTT 19 -1 \n", + "CBE_CONTROL_98_pos NaN TCCTCATCCGGTCAGGCTG 19 -1 \n", + "CBE_CONTROL_99_pos NaN TAACGCGCATATCTGAACAC 20 -1 \n", + "\n", + " target_start masked_sequence masked_barcode \\\n", + "name \n", + "5' UTR+pro_28_neg 5.0 TCCTTCGGGGGTGTCGCCGG CCTC \n", + "5' UTR+pro_28_pos 5.0 CTTTCGGGGGGCTGGGGTGG CGGG \n", + "5' UTR+pro_31_neg 5.0 CGGTCCTTCGGGGGTGTCGC CCTG \n", + "5' UTR+pro_31_pos 5.0 TCGGGGGGCTGGGGTGGGGG TGCC \n", + "5' UTR+pro_34_neg 5.0 CTCCGGTCCTTCGGGGGTGT CTGG \n", + "... ... ... ... \n", + "CBE_CONTROL_40_pos NaN TGTCGCGCTTGGGTTGTGCG TGCC \n", + "CBE_CONTROL_96_pos NaN GTTGGCCGTTGCCGTGTCGG GCTC \n", + "CBE_CONTROL_97_pos NaN GTCCCTCGGGGTGCGGCTT GGGG \n", + "CBE_CONTROL_98_pos NaN TCCTCGTCCGGTCGGGCTG GGGC \n", + "CBE_CONTROL_99_pos NaN TGGCGCGCGTGTCTGGGCGC GGGG \n", + "\n", + " Group edit_rate \n", + "name \n", + "5' UTR+pro_28_neg 5' UTR+pro NaN \n", + "5' UTR+pro_28_pos 5' UTR+pro 0.323475 \n", + "5' UTR+pro_31_neg 5' UTR+pro NaN \n", + "5' UTR+pro_31_pos 5' UTR+pro 0.026911 \n", + "5' UTR+pro_34_neg 5' UTR+pro 0.340677 \n", + "... ... ... 
\n", + "CBE_CONTROL_40_pos CBE control NaN \n", + "CBE_CONTROL_96_pos CBE control 0.772672 \n", + "CBE_CONTROL_97_pos CBE control NaN \n", + "CBE_CONTROL_98_pos CBE control 0.573492 \n", + "CBE_CONTROL_99_pos CBE control 0.004608 \n", + "\n", + "[7500 rows x 16 columns]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bdata.guides" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (anbe_py38)", + "language": "python", + "name": "myenv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/index.rst b/docs/index.rst index 7aaedda..ae36ced 100755 --- a/docs/index.rst +++ b/docs/index.rst @@ -42,6 +42,14 @@ Workflows tutorials +Model description +-------------------------- +.. toctree:: + :maxdepth: 2 + + model + + API references -------------------------- .. toctree:: diff --git a/docs/model.rst b/docs/model.rst new file mode 100644 index 0000000..63e8365 --- /dev/null +++ b/docs/model.rst @@ -0,0 +1,5 @@ +.. _model: + +BEAN model +*********************** +.. mdinclude:: _model.md diff --git a/docs/subcommands.rst b/docs/subcommands.rst index 8842def..7432ac3 100755 --- a/docs/subcommands.rst +++ b/docs/subcommands.rst @@ -12,4 +12,5 @@ Subcommands qc filter run - create_screen \ No newline at end of file + create_screen + build_prior \ No newline at end of file diff --git a/docs/tutorial_custom_prior.rst b/docs/tutorial_custom_prior.rst new file mode 100644 index 0000000..9500494 --- /dev/null +++ b/docs/tutorial_custom_prior.rst @@ -0,0 +1,5 @@ +.. _tutorial_custom_prior: + +.. mdinclude:: _tutorial_custom_prior.md + +See :ref:`subcommands` for the full details. diff --git a/docs/tutorials.rst b/docs/tutorials.rst index c098b9b..cec0918 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -11,4 +11,5 @@ Workflow tutorials tutorial_prolif_gwas tutorial_prolif_cds tutorial_no_edit - Visualize result (variant screens) \ No newline at end of file + Visualize result (variant screens) + tutorial_custom_prior \ No newline at end of file diff --git a/setup.py b/setup.py index ca388aa..c492064 100755 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="crispr-bean", - version="1.2.8", + version="1.2.9", python_requires=">=3.8.0", author="Jayoung Ryu", author_email="jayoung_ryu@g.harvard.edu", diff --git a/tests/test_run.py b/tests/test_run.py index 7f60b85..99680bf 100755 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -224,3 +224,8 @@ def test_run_tiling_no_translation(): ) except subprocess.CalledProcessError as exc: raise exc + + +# TODO: make test data for --prior-params +# TODO: semantic testing on splitting a single screen into two to train +# TODO: add test using --prior-params
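A sketch of what the `--prior-params` test flagged in the TODOs above could look like, following the subprocess pattern of the existing tests. This is an assumption-laden sketch, not a shipped test: the fixture path `tests/data/var_mini_screen_masked.h5ad`, the target count, and the condition name are all hypothetical and must be replaced with real test data.

```python
# Hypothetical sketch for the TODOs above — not part of this diff.
import pickle as pkl
import subprocess

import torch


def test_run_variant_with_prior_params(tmp_path):
    # Assumption: n_targets must equal the number of targets in the fixture screen.
    n_targets = 33
    prior_path = tmp_path / "prior_params.pkl"
    with open(prior_path, "wb") as f:
        pkl.dump(
            {
                "mu_loc": torch.zeros((n_targets, 1)),
                "mu_scale": torch.ones((n_targets, 1)),
            },
            f,
        )
    # Assumption: fixture path and --control-condition value must match real test data.
    cmd = (
        "bean run sorting variant tests/data/var_mini_screen_masked.h5ad "
        f"--control-condition bulk -o {tmp_path} --prior-params {prior_path}"
    )
    try:
        subprocess.check_output(
            cmd, shell=True, universal_newlines=True, stderr=subprocess.STDOUT
        )
    except subprocess.CalledProcessError as exc:
        raise exc
```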