Commit
Merge branch 'dev' into main
jykr committed Jun 20, 2024
2 parents 07f858e + 366944c commit b57f288
Showing 29 changed files with 1,996 additions and 476 deletions.
44 changes: 0 additions & 44 deletions .github/workflows/pypi_release.yml

This file was deleted.

12 changes: 6 additions & 6 deletions CHANGELOG.md
@@ -1,9 +1,9 @@
 # Changelog
-## 1.2.5
-* Allow `bean run .. tiling` for untranslated `--allele-df-key`.
-
+## 1.2.8
+* Change .pyx files to be compatible with more recent numpy versions
+## 1.2.7
+* **CRITICAL** Fix sample ordering & masking issue for survival screens
 ## 1.2.6
 * Fix overflow in `bean run survival` and autograd error related to inplace assignment for `bean run survival tiling`.
 
-## 1.2.7
-* **CRITICAL** Fix sample ordering & masking issue for survival screens
+## 1.2.5
+* Allow `bean run .. tiling` for untranslated `--allele-df-key`.
4 changes: 3 additions & 1 deletion README.md
@@ -22,9 +22,10 @@
 2. [`profile`](https://pinellolab.github.io/crispr-bean/profile.html): Profile editing preferences of your editor.
 3. [`qc`](https://pinellolab.github.io/crispr-bean/qc.html): Quality control report and filtering out / masking of aberrant samples and guides
 4. [`filter`](https://pinellolab.github.io/crispr-bean/filter.html): Filter reporter alleles; essential for `tiling` mode that allows for all alleles generated from gRNA.
-5. [`run`](https://pinellolab.github.io/crispr-bean/run.html): Quantify targeted variants' effect sizes from screen data.
+5. [`run`](https://pinellolab.github.io/crispr-bean/run.html): Quantify targeted variants' effect sizes from screen data. **See more about the model in the link**.
 * Screen data is saved as [`ReporterScreen` object](https://pinellolab.github.io/crispr-bean/reporterscreen.html) in the pipeline.
 BEAN stores mapped gRNA and allele counts in `ReporterScreen` object which is compatible with [AnnData](https://anndata.readthedocs.io/en/latest/index.html).
+
 ## Installation
 First install [PyTorch](https://pytorch.org/get-started/).
 Then download from PyPI:
@@ -50,6 +51,7 @@ See the [documentation](https://pinellolab.github.io/crispr-bean/) for tutorials
 | GWAS variant library | Survival / Proliferation | Yes/No | [GWAS variant screen](https://pinellolab.github.io/crispr-bean/tutorial_prolif_gwas.html)
 | Coding sequence tiling library | Survival / Proliferation | Yes/No | [Coding sequence tiling screen](https://pinellolab.github.io/crispr-bean/tutorial_prolif_cds.html)
 | Perturbation library without reporter | FACS sorting | No | [No reporter screen](https://pinellolab.github.io/crispr-bean/tutorial_no_edit.html)
+| Integration of disjoint libraries | Any | Any | [Feeding custom prior](https://pinellolab.github.io/crispr-bean/tutorial_custom_prior.html)
 
 Also see notebook that visualizes screen analysis result [here](https://github.com/pinellolab/crispr-bean/blob/main/docs/visualize_var.ipynb).
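Since `ReporterScreen` is AnnData-compatible, guide-level metadata travels with the counts. A minimal sketch, assuming the screen was saved as `.h5ad` and that `bean` exposes an AnnData-style `read_h5ad` loader (the loader name and path are assumptions, not confirmed by this diff):

# Sketch: loading a saved ReporterScreen and inspecting guide metadata,
# which bean/cli/run.py below accesses as `ndata.screen.guides`.
import bean as be

screen = be.read_h5ad("myscreen.h5ad")  # loader name assumed; path is a placeholder
print(screen.guides.head())             # per-guide metadata as a pandas DataFrame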
72 changes: 72 additions & 0 deletions bean/cli/build_prior.py
@@ -0,0 +1,72 @@
+import pickle as pkl
+import numpy as np
+import torch
+from bean.model.run import _get_guide_target_info
+from bean.model.parser import parse_args
+from bean.cli.run import main as get_screendata
+from bean.preprocessing.data_class import SortingScreenData
+
+
+def generate_prior_data_for_disjoint_library_pair(
+    command1: str, command2: str, output1_path: str, prior_params_path: str
+):
+    """Generate prior for two batches with disjoint guides but with shared variants."""
+    with open(output1_path, "rb") as f:
+        data = pkl.load(f)
+    ndata = data["data"]
+    parser = parse_args()
+    command1 = command1.split("bean run ")[-1]
+    command2 = command2.split("bean run ")[-1]
+    args = parser.parse_args(command1.split(" "))
+    args2 = parser.parse_args(command2.split(" "))
+    ndata2 = get_screendata(args2, return_data=True)
+    target_df = _get_guide_target_info(
+        ndata.screen, args, cols_include=[args.negctrl_col]
+    )
+    target_df2 = _get_guide_target_info(
+        ndata2.screen, args2, cols_include=[args2.negctrl_col]
+    )
+    batch1_idx = np.where(
+        target_df.index.map(lambda s: s in target_df2.index.tolist())
+    )[0]
+    batch2_idx = []
+    for i in batch1_idx:
+        batch2_idx.append(
+            np.where(target_df.index.tolist()[i] == target_df2.index)[0].item()
+        )
+    batch2_idx = np.array(batch2_idx)
+    if isinstance(ndata, SortingScreenData):
+        mu_loc = torch.zeros((ndata2.n_targets, 1))
+        mu_loc[batch2_idx, :] = data["params"]["mu_loc"][batch1_idx, :]
+        mu_scale = torch.ones((ndata2.n_targets, 1))
+        mu_scale[batch2_idx, :] = data["params"]["mu_scale"][batch1_idx, :]
+        sd_loc = torch.zeros((ndata2.n_targets, 1))
+        sd_loc[batch2_idx, :] = data["params"]["sd_loc"][batch1_idx, :]
+        sd_scale = torch.ones((ndata2.n_targets, 1)) * 0.01
+        sd_scale[batch2_idx, :] = data["params"]["sd_scale"][batch1_idx, :]
+        prior_params = {
+            "mu_loc": mu_loc,
+            "mu_scale": mu_scale,
+            "sd_loc": sd_loc,
+            "sd_scale": sd_scale,
+        }
+    else:
+        mu_loc = torch.zeros((ndata2.n_targets, 1))
+        mu_loc[batch2_idx, :] = data["params"]["mu_loc"][batch1_idx, :]
+        mu_scale = torch.ones((ndata2.n_targets, 1))
+        mu_scale[batch2_idx, :] = data["params"]["mu_scale"][batch1_idx, :]
+        prior_params = {
+            "mu_loc": mu_loc,
+            "mu_scale": mu_scale,
+        }
+    with open(prior_params_path, "wb") as f:
+        pkl.dump(prior_params, f)
+    print(
+        f"Successfully generated prior parameters at {prior_params_path}. To use this parameter, run:\nbean run {command2+' --prior-params '+prior_params_path}"
+    )
+
+
+def main(args):
+    generate_prior_data_for_disjoint_library_pair(
+        args.command1, args.command2, args.raw_run_output1, args.output_path
+    )
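A sketch of how this helper is meant to be driven, based only on the code above. The two `bean run` command strings and all file paths are hypothetical placeholders; the pickle passed as `output1_path` must contain the `data` and `params` entries unpickled above, i.e. come from a first run that saved its raw output:

# Hypothetical invocation; command strings and paths are placeholders.
from bean.cli.build_prior import generate_prior_data_for_disjoint_library_pair

generate_prior_data_for_disjoint_library_pair(
    command1="bean run sorting variant batch1.h5ad -o batch1_out",  # run that produced the pickle
    command2="bean run sorting variant batch2.h5ad -o batch2_out",  # run to receive the prior
    output1_path="batch1_out/MyModel.result.pkl",  # raw result pickle from run 1
    prior_params_path="prior_params.pkl",          # prior to feed into run 2
)
# On success it prints the follow-up command, e.g.
# bean run sorting variant batch2.h5ad -o batch2_out --prior-params prior_params.pkl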
16 changes: 16 additions & 0 deletions bean/cli/execute.py
@@ -7,6 +7,7 @@
 from bean.model.parser import parse_args as get_run_parser
 from bean.framework.parser import get_input_parser as get_create_screen_parser
 from bean.annotate.utils import get_splice_parser as get_splice_site_parser
+from bean.model.parser_prior import parse_args as get_prior_parser
 from bean.cli.count import main as count
 from bean.cli.count_samples import main as count_samples
 from bean.cli.profile import main as profile
@@ -15,6 +16,15 @@
 from bean.cli.run import main as run
 from bean.cli.create_screen import main as create_screen
 from bean.cli.get_splice_sites import main as get_splice_sites
+from bean.cli.build_prior import main as build_prior
+
+import warnings
+
+warnings.filterwarnings(
+    action="ignore",
+    category=FutureWarning,
+    message=r".*The default of observed=False is deprecated and will be changed to True in a future version of pandas.*",
+)
 
 
 def get_parser():
@@ -40,6 +50,10 @@ def get_parser():
         "get-splice-sites", help="get splice sites"
     )
     splice_site_parser = get_splice_site_parser(splice_site_parser)
+    prior_parser = subparsers.add_parser(
+        "build-prior", help="obtain prior_params.pkl for batched runs"
+    )
+    prior_parser = get_prior_parser(prior_parser)
     return parser
 
 
@@ -65,5 +79,7 @@ def main() -> None:
         create_screen(args)
     elif args.subcommand == "get-splice-sites":
         get_splice_sites(args)
+    elif args.subcommand == "build-prior":
+        build_prior(args)
     else:
         parser.print_help()
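The new subcommand is dispatched through the same entry point as the existing ones. A quick way to see the wiring work, with `--help` standing in for real arguments (the actual flags live in `bean/model/parser_prior.py`, which this diff does not show):

# Exercising the newly registered build-prior subcommand via the entry point.
import sys
from bean.cli.execute import main

sys.argv = ["bean", "build-prior", "--help"]
main()  # argparse prints the build-prior usage and exits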
20 changes: 16 additions & 4 deletions bean/cli/run.py
@@ -31,6 +31,7 @@
     check_args,
     identify_model_guide,
     identify_negctrl_model_guide,
+    _check_prior_params,
 )
 
 logging.basicConfig(
@@ -60,7 +61,7 @@
 )
 
 
-def main(args):
+def main(args, return_data=False):
     print(
         r"""
     _ _
@@ -114,7 +115,8 @@ def main(args):
         replicate_col=args.replicate_col,
         use_bcmatch=(not args.ignore_bcmatch),
     )
-
+    if return_data:
+        return ndata
     # Build variant dataframe
     adj_negctrl_idx = None
     if args.library_design == "variant":
@@ -183,6 +185,11 @@
     )
     guide_info_df = ndata.screen.guides
 
+    # Add user-defined prior.
+    if args.prior_params is not None:
+        prior_params = _check_prior_params(args.prior_params, ndata)
+        model = partial(model, prior_params=prior_params)
+
     # Run the inference steps
     info(f"Running inference for {model_label}...")
     if args.load_existing:
@@ -211,15 +218,20 @@
         )
     else:
         param_history_dict_negctrl = None
+    save_dict["data"] = ndata
     # Save results
 
     outfile_path = (
         f"{prefix}/bean_element[sgRNA]_result.{model_label}{args.result_suffix}.csv"
     )
     info(f"Done running inference. Writing result at {outfile_path}...")
     if not os.path.exists(prefix):
         os.makedirs(prefix)
-    with open(f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb") as handle:
-        pkl.dump(save_dict, handle)
+    if args.save_raw:
+        with open(
+            f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb"
+        ) as handle:
+            pkl.dump(save_dict, handle)
     write_result_table(
         target_info_df,
         param_history_dict,
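For reference, a sketch of the payload that `--prior-params` is expected to point to for a sorting screen, mirroring the tensors constructed in `bean/cli/build_prior.py` above (`n_targets` is a made-up library size). `partial(model, prior_params=prior_params)` pre-binds these tensors, so the rest of the inference code keeps calling the model unchanged:

# Assembling a prior_params.pkl by hand; shapes follow build_prior.py above.
import pickle as pkl
import torch

n_targets = 100  # hypothetical number of targeted variants
prior_params = {
    "mu_loc": torch.zeros((n_targets, 1)),          # prior mean of effect size
    "mu_scale": torch.ones((n_targets, 1)),         # prior scale of effect size
    "sd_loc": torch.zeros((n_targets, 1)),          # prior location of per-variant sd
    "sd_scale": torch.ones((n_targets, 1)) * 0.01,  # default used where no prior estimate exists
}
with open("prior_params.pkl", "wb") as f:
    pkl.dump(prior_params, f)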
5 changes: 5 additions & 0 deletions bean/mapping/utils.py
@@ -321,6 +321,7 @@ def _check_arguments(args, info_logger, warn_logger, error_logger):
     sgRNA_info_tbl = pd.read_csv(args.sgRNA_filename)
 
     def _check_sgrna_info_table(args, sgRNA_info_tbl):
+        # Check column names
         if args.offset:
             if "offset" not in sgRNA_info_tbl.columns:
                 raise InputFileError(
@@ -345,6 +346,10 @@ def _check_sgrna_info_table(args, sgRNA_info_tbl):
             raise InputFileError(
                 f"Offset option is set but the input file doesn't contain the `reporter` column: Provided {sgRNA_info_tbl.columns}"
             )
+        if sgRNA_info_tbl["name"].duplicated().any():
+            raise InputFileError(
+                f"Duplicate guide names: {sgRNA_info_tbl.loc[sgRNA_info_tbl['name'].duplicated(),:].index}. Please provide unique IDs for each guide."
+            )
 
     _check_sgrna_info_table(args, sgRNA_info_tbl)
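A tiny illustration of the duplicate-name guard added above, with made-up data:

# A guide table with a repeated name; the check above raises InputFileError
# for input like this.
import pandas as pd

sgRNA_info_tbl = pd.DataFrame({"name": ["g1", "g2", "g2"], "sequence": ["ACGT", "CCGA", "GGTA"]})
print(sgRNA_info_tbl["name"].duplicated().any())  # True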
