Skip to content

Commit

Permalink
Merge pull request #44 from pinellolab/dev
Browse files Browse the repository at this point in the history
Proliferation model (initial version) and improved mapping
  • Loading branch information
jykr authored Nov 1, 2024
2 parents b0a2d1b + ba0d9c3 commit 992f4dc
Show file tree
Hide file tree
Showing 24 changed files with 11,688 additions and 11,284 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:

- name: Install Sphinx & Dependencies
run: |
pip install sphinx sphinx_markdown_builder sphinx_rtd_theme sphinx-argparse m2r pandas bio
pip install sphinx sphinx_markdown_builder sphinx_rtd_theme sphinx-argparse m2r pandas bio "docutils==0.20"
sudo apt-get install python3-distutils
- name: Build Documentation
working-directory: docs
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
2. [`profile`](https://pinellolab.github.io/crispr-bean/profile.html): Profile editing preferences of your editor.
3. [`qc`](https://pinellolab.github.io/crispr-bean/qc.html): Quality control report and filtering out / masking of aberrant samples and guides
4. [`filter`](https://pinellolab.github.io/crispr-bean/filter.html): Filter reporter alleles; essential for `tiling` mode that allows for all alleles generated from gRNA.
5. [`run`](https://pinellolab.github.io/crispr-bean/run.html): Quantify targeted variants' effect sizes from screen data. **See more about the [model](https://pinellolab.github.io/crispr-bean/model.html) & [output](https://github.com/pinellolab/crispr-bean/tree/main/docs/example_run_output)**
5. [`run`](https://pinellolab.github.io/crispr-bean/run.html): Quantify targeted variants' effect sizes from screen data. **See more about the [model](https://pinellolab.github.io/crispr-bean/model.html) & [output](https://pinellolab.github.io/crispr-bean/run.html#output)**
* Screen data is saved as a [`ReporterScreen` object](https://pinellolab.github.io/crispr-bean/reporterscreen.html) in the pipeline.
BEAN stores mapped gRNA and allele counts in a `ReporterScreen` object, which is compatible with [AnnData](https://anndata.readthedocs.io/en/latest/index.html).

Expand Down
10 changes: 10 additions & 0 deletions bean/annotate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,16 @@ def check_args(args):
raise ValueError(
"Invalid arguments: You should specify exactly one of --translate-fasta, --translate-fastas-csv, --translate-gene, translate-genes-list to translate alleles."
)
if (
args.translate_fasta is not None
or args.translate_fastas_csv is not None
or args.translate_gene is not None
or args.translate_genes_list is not None
) and not args.translate:
warn(
"fastq or gene files for translation provided without `--translate` flag. Setting `--translate` flag to True."
)
args.translate = True
if args.translate_genes_list is not None:
args.translate_genes_list = (
pd.read_csv(args.translate_genes_list, header=None).values[:, 0].tolist()
Expand Down
3 changes: 3 additions & 0 deletions bean/cli/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ def main(args):
counter.screen.uns["allele_counts"] = counter.screen.uns["allele_counts"].loc[
counter.screen.uns["allele_counts"].allele.map(str) != "", :
]
counter.screen.get_edit_from_allele("allele_counts", "allele")
if match_target_pos:
counter.screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
else:
counter.screen.get_edit_mat_from_uns(target_base_edits)
counter.screen.write(f"{counter.output_dir}.h5ad")
counter.screen.to_Excel(f"{counter.output_dir}.xlsx")
info(f"Output written at:\n {counter.output_dir}.h5ad,\n {counter.output_dir}.xlsx")
Expand Down
11 changes: 8 additions & 3 deletions bean/cli/count_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace):
except KeyError as e:
raise KeyError(args_dict["edited_base"]) from e

match_target_pos = args_dict["match_target_pos"]
match_target_pos = (
args_dict["match_target_pos"] if not args_dict["tiling"] else False
)
if (
"guide_start_seqs_tbl" in args_dict
and args_dict["guide_start_seqs_tbl"] is not None
Expand Down Expand Up @@ -92,12 +94,15 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace):
screen = counter.screen
if screen.X.max() == 0:
warn(f"Nothing counted for {sample_id}. Check your input.")
if counter.count_reporter_edits and match_target_pos:
if counter.count_reporter_edits:
screen.uns["allele_counts"] = screen.uns["allele_counts"].loc[
screen.uns["allele_counts"].allele.map(str) != "", :
]
screen.get_edit_from_allele("allele_counts", "allele")
screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
if match_target_pos:
screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
else:
screen.get_edit_mat_from_uns(target_base_edits)
info(
f"Done for {sample_id}. \n\
Output written at {counter.output_dir}.h5ad"
Expand Down
47 changes: 40 additions & 7 deletions bean/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def main(args, return_data=False):
file_logger = logging.FileHandler(f"{prefix}/bean_run.log")
file_logger.setLevel(logging.INFO)
logging.getLogger().addHandler(file_logger)
info(f"Running: {' '.join(sys.argv[:])}")
if args.cuda:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
torch.set_default_tensor_type(torch.FloatTensor)
Expand All @@ -96,6 +96,10 @@ def main(args, return_data=False):

# Format bdata into data structure compatible with Pyro model
bdata = prepare_bdata(bdata, args, warn, prefix)
negctrl_idx = np.where(
bdata.guides[args.negctrl_col].map(lambda s: s.lower())
== args.negctrl_col_value.lower()
)[0]
ndata = DATACLASS_DICT[args.selection][model_label](
screen=bdata,
device=args.device,
Expand All @@ -115,20 +119,24 @@ def main(args, return_data=False):
popt=args.popt,
replicate_col=args.replicate_col,
use_bcmatch=(not args.ignore_bcmatch),
negctrl_guide_idx=negctrl_idx,
)
if args.save_raw:
pkl.dump(bdata, open(f"{prefix}/ndata.pkl", "wb"))
guide_index = ndata.screen.guides.index.copy()
assert len(guide_index) == bdata.n_obs, (len(guide_index), bdata.n_obs)
if return_data:
return ndata
# Build variant dataframe
adj_negctrl_idx = None
_control_condition = args.control_condition.split(",")[0]
if args.library_design == "variant":
if not args.uniform_edit:
if "edit_rate" not in ndata.screen.guides.columns:
ndata.screen.get_edit_from_allele()
ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True)
ndata.screen.get_guide_edit_rate(
unsorted_condition_label=args.control_condition
unsorted_condition_label=_control_condition
)
target_info_df = _get_guide_target_info(
ndata.screen, args, cols_include=[args.negctrl_col]
Expand All @@ -139,6 +147,12 @@ def main(args, return_data=False):
== args.negctrl_col_value.lower()
)[0]
else:
if "edit_rate_norm" not in ndata.screen.guides.columns:
ndata.screen.get_edit_from_allele()
ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True)
ndata.screen.get_guide_edit_rate(
unsorted_condition_label=_control_condition
)
if args.splice_site_path is not None:
splice_site = pd.read_csv(args.splice_site_path).pos
target_info_df = be.an.translate_allele.annotate_edit(
Expand Down Expand Up @@ -198,7 +212,6 @@ def main(args, return_data=False):
],
axis=1,
)

# Add user-defined prior.
if args.prior_params is not None:
prior_params = _check_prior_params(args.prior_params, ndata)
Expand All @@ -210,13 +223,12 @@ def main(args, return_data=False):
with open(f"{prefix}/{model_label}.result.pkl", "rb") as handle:
param_history_dict = pkl.load(handle)
else:
param_history_dict, save_dict = deepcopy(
run_inference(model, guide, ndata, num_steps=args.n_iter)
)
save_dict = dict()
if args.fit_negctrl:
negctrl_model, negctrl_guide = identify_negctrl_model_guide(
args, "X_bcmatch" in bdata.layers
)
print(f"Using {negctrl_model} & {negctrl_guide}")
negctrl_idx = np.where(
ndata.screen.guides[args.negctrl_col].map(lambda s: s.lower())
== args.negctrl_col_value.lower()
Expand All @@ -225,16 +237,36 @@ def main(args, return_data=False):
f"Using {len(negctrl_idx)} negative control elements to adjust phenotypic effect sizes..."
)
ndata_negctrl = ndata[negctrl_idx]
print(
f"ndata size factor: {ndata.size_factor}, {ndata_negctrl.size_factor}\esf"
)
if args.save_raw:
pkl.dump(ndata_negctrl, open(f"{prefix}/ndata_negctrl.pkl", "wb"))
param_history_dict_negctrl, save_dict["negctrl"] = deepcopy(
run_inference(
negctrl_model, negctrl_guide, ndata_negctrl, num_steps=args.n_iter
)
)
if args.selection == "survival":
print(
f"Feeding mu_negctrl={param_history_dict_negctrl['mu_loc'], param_history_dict_negctrl['mu_scale']} into model..."
)
model = partial(
model,
mu_negctrl=(
param_history_dict_negctrl["mu_loc"].detach().mean(),
param_history_dict_negctrl["mu_scale"].detach().mean(),
),
)
else:
param_history_dict_negctrl = None
save_dict["data"] = ndata
param_history_dict, save_dict_model = deepcopy(
run_inference(model, guide, ndata, num_steps=args.n_iter)
)
for k, v in save_dict_model.items():
save_dict[k] = v
# Save results

outfile_path = (
f"{prefix}/bean_element[sgRNA]_result.{model_label}{args.result_suffix}.csv"
)
Expand Down Expand Up @@ -266,5 +298,6 @@ def main(args, return_data=False):
sample_covariates=(
ndata.sample_covariates if hasattr(ndata, "sample_covariates") else None
),
is_survival_screen=(args.selection == "survival"),
)
info("Done!")
24 changes: 15 additions & 9 deletions bean/framework/ReporterScreen.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,9 @@ def get_edit_mat_from_uns(
reporter_length = self.uns["reporter_length"]
if "reproter_right_flank_length" in self.uns:
reporter_right_flank_length = self.uns["reporter_right_flank_length"]
if edit_count_key not in self.uns or len(self.uns[edit_count_key]) == 0:
if edit_count_key not in self.uns:
raise ValueError(
"Edit count isn't calculated. "
f"Edit count isn't calculated or not provided with specified key `{edit_count_key}`. "
+ "Call .get_edit_from_allele(allele_count_key, allele_key)"
)
edits = self.uns[edit_count_key].copy()
Expand All @@ -410,7 +410,7 @@ def get_edit_mat_from_uns(
edits["guide_start_pos"] = (
reporter_length
- reporter_right_flank_length
- guide_len[edits.guide_idx].reset_index(drop=True)
- guide_len.iloc[edits.guide_idx].reset_index(drop=True)
)
if not match_target_position:
edits["rel_pos"] = edits.edit.map(lambda e: e.rel_pos)
Expand All @@ -434,12 +434,14 @@ def get_edit_mat_from_uns(
edits.target_pos_matches,
["guide", "edit"] + self.samples.index.tolist(),
]

good_edits = good_edits.copy()
good_guide_idx = guide_name_to_idx.loc[good_edits.guide, "index"].astype(int)
edit_mat = np.zeros(self.layers["edits"].shape)
for gidx, eidx in zip(good_guide_idx, good_edits.index):
self.layers["edits"][gidx, :] += good_edits.loc[
edit_mat[gidx, :] = edit_mat[gidx, :] + good_edits.loc[
eidx, self.samples.index.tolist()
].astype(int)
self.layers["edits"] = edit_mat
print("New edit matrix saved in .layers['edits']. Returning old edits.")
return old_edits

Expand Down Expand Up @@ -505,17 +507,21 @@ def get_guide_edit_rate(
prior_weight = 1
n_edits = self.layers[edit_layer].copy()[:, bulk_idx].sum(axis=1)
n_counts = self.layers[count_layer].copy()[:, bulk_idx].sum(axis=1)
edit_rate = (n_edits + prior_weight / 2) / (
(n_counts * num_targetable_sites) + prior_weight / 2
)
edit_rate = (n_edits + prior_weight / 2) / ((n_counts) + prior_weight / 2)
if normalize_by_editable_base:
edit_rate_norm = (n_edits + prior_weight / 2) / (
(n_counts * num_targetable_sites) + prior_weight / 2
)
edit_rate[n_counts < bcmatch_thres] = np.nan
if normalize_by_editable_base:
print("normalize by editable counts")
edit_rate[num_targetable_sites == 0] = np.nan
edit_rate_norm[num_targetable_sites == 0] = np.nan
if return_result:
return edit_rate
else:
self.guides["edit_rate"] = edit_rate
if normalize_by_editable_base:
self.guides["edit_rate_norm"] = edit_rate_norm
print(self.guides.edit_rate)

def get_edit_rate(
Expand Down
Loading

0 comments on commit 992f4dc

Please sign in to comment.