Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proliferation model (initial version) and improved mapping #44

Merged
merged 33 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
13cbd6d
clearer error message in QC
jykr Jul 2, 2024
8b3d7bc
adding per-replicate LFC and total editing rate
jykr Aug 14, 2024
cd294b6
output editing rate also for uniform mode
jykr Aug 14, 2024
86572bc
remove reordering of columns to troubleshoot dataset without editing …
jykr Aug 14, 2024
2fdb440
adjustments for survival models
jykr Aug 15, 2024
c7b7f54
debug non-numeric error for time column
jykr Aug 16, 2024
e5b3a90
allow user-provided control-conditions
jykr Aug 23, 2024
3b40968
debug errors in calculating edits
jykr Aug 26, 2024
c5ac7ad
allow 0 edit counts
jykr Aug 26, 2024
ac070b6
debug sample ordering
jykr Aug 26, 2024
8ad2475
twoctrls6
jykr Aug 27, 2024
cd36cce
twoctrls7
jykr Aug 27, 2024
b50ecf2
twoctrls8
jykr Aug 27, 2024
793c623
twoctrls9
jykr Aug 27, 2024
21db8f4
twoctrls10
jykr Aug 27, 2024
d526512
twoctrls11
jykr Aug 28, 2024
5359f46
set correct mu0 for nonedited
jykr Aug 29, 2024
a8b31dd
don't allow training of mu0 with non-negctrl
jykr Aug 29, 2024
628636e
model16
jykr Aug 29, 2024
7d41341
model18
jykr Aug 29, 2024
112ab89
model20
jykr Aug 29, 2024
d5572b9
model21
jykr Aug 29, 2024
71a95e4
model24
jykr Sep 4, 2024
1c84b97
model25: normalize by total abundance
jykr Sep 4, 2024
4722d40
model26
jykr Sep 4, 2024
883bcd9
model29
jykr Sep 5, 2024
5b9ea5f
model30: should I double the norm factor?
jykr Sep 5, 2024
6b4d8fd
model31
jykr Sep 5, 2024
acac2ad
model31
jykr Sep 6, 2024
8125c26
tiling screen update
jykr Oct 2, 2024
54f0026
filter out R2 with too short reporter length
jykr Oct 15, 2024
80c7afc
allow duplicated matching to the best match
jykr Nov 1, 2024
ba0d9c3
pin docutils version for Sphinx
jykr Nov 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:

- name: Install Sphinx & Dependencies
run: |
pip install sphinx sphinx_markdown_builder sphinx_rtd_theme sphinx-argparse m2r pandas bio
pip install sphinx sphinx_markdown_builder sphinx_rtd_theme sphinx-argparse m2r pandas bio "docutils==0.20"
sudo apt-get install python3-distutils
- name: Build Documentation
working-directory: docs
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
2. [`profile`](https://pinellolab.github.io/crispr-bean/profile.html): Profile editing preferences of your editor.
3. [`qc`](https://pinellolab.github.io/crispr-bean/qc.html): Quality control report and filtering out / masking of aberrant sample and guides
4. [`filter`](https://pinellolab.github.io/crispr-bean/filter.html): Filter reporter alleles; essential for `tiling` mode that allows for all alleles generated from gRNA.
5. [`run`](https://pinellolab.github.io/crispr-bean/run.html): Quantify targeted variants' effect sizes from screen data. **See more about the [model](https://pinellolab.github.io/crispr-bean/model.html) & [output](https://github.com/pinellolab/crispr-bean/tree/main/docs/example_run_output)**
5. [`run`](https://pinellolab.github.io/crispr-bean/run.html): Quantify targeted variants' effect sizes from screen data. **See more about the [model](https://pinellolab.github.io/crispr-bean/model.html) & [output](https://pinellolab.github.io/crispr-bean/run.html#output)**
* Screen data is saved as [`ReporterScreen` object](https://pinellolab.github.io/crispr-bean/reporterscreen.html) in the pipeline.
BEAN stores mapped gRNA and allele counts in `ReporterScreen` object which is compatible with [AnnData](https://anndata.readthedocs.io/en/latest/index.html).

Expand Down
10 changes: 10 additions & 0 deletions bean/annotate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,16 @@ def check_args(args):
raise ValueError(
"Invalid arguments: You should specify exactly one of --translate-fasta, --translate-fastas-csv, --translate-gene, translate-genes-list to translate alleles."
)
if (
args.translate_fasta is not None
or args.translate_fastas_csv is not None
or args.translate_gene is not None
or args.translate_genes_list is not None
) and not args.translate:
warn(
"fastq or gene files for translation provided without `--translate` flag. Setting `--translate` flag to True."
)
args.translate = True
if args.translate_genes_list is not None:
args.translate_genes_list = (
pd.read_csv(args.translate_genes_list, header=None).values[:, 0].tolist()
Expand Down
3 changes: 3 additions & 0 deletions bean/cli/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ def main(args):
counter.screen.uns["allele_counts"] = counter.screen.uns["allele_counts"].loc[
counter.screen.uns["allele_counts"].allele.map(str) != "", :
]
counter.screen.get_edit_from_allele("allele_counts", "allele")
if match_target_pos:
counter.screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
else:
counter.screen.get_edit_mat_from_uns(target_base_edits)
counter.screen.write(f"{counter.output_dir}.h5ad")
counter.screen.to_Excel(f"{counter.output_dir}.xlsx")
info(f"Output written at:\n {counter.output_dir}.h5ad,\n {counter.output_dir}.xlsx")
Expand Down
11 changes: 8 additions & 3 deletions bean/cli/count_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace):
except KeyError as e:
raise KeyError(args_dict["edited_base"]) from e

match_target_pos = args_dict["match_target_pos"]
match_target_pos = (
args_dict["match_target_pos"] if not args_dict["tiling"] else False
)
if (
"guide_start_seqs_tbl" in args_dict
and args_dict["guide_start_seqs_tbl"] is not None
Expand Down Expand Up @@ -92,12 +94,15 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace):
screen = counter.screen
if screen.X.max() == 0:
warn(f"Nothing counted for {sample_id}. Check your input.")
if counter.count_reporter_edits and match_target_pos:
if counter.count_reporter_edits:
screen.uns["allele_counts"] = screen.uns["allele_counts"].loc[
screen.uns["allele_counts"].allele.map(str) != "", :
]
screen.get_edit_from_allele("allele_counts", "allele")
screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
if match_target_pos:
screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
else:
screen.get_edit_mat_from_uns(target_base_edits)
info(
f"Done for {sample_id}. \n\
Output written at {counter.output_dir}.h5ad"
Expand Down
47 changes: 40 additions & 7 deletions bean/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def main(args, return_data=False):
file_logger = logging.FileHandler(f"{prefix}/bean_run.log")
file_logger.setLevel(logging.INFO)
logging.getLogger().addHandler(file_logger)
info(f"Running: {' '.join(sys.argv[:])}")
if args.cuda:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
torch.set_default_tensor_type(torch.FloatTensor)
Expand All @@ -96,6 +96,10 @@ def main(args, return_data=False):

# Format bdata into data structure compatible with Pyro model
bdata = prepare_bdata(bdata, args, warn, prefix)
negctrl_idx = np.where(
bdata.guides[args.negctrl_col].map(lambda s: s.lower())
== args.negctrl_col_value.lower()
)[0]
ndata = DATACLASS_DICT[args.selection][model_label](
screen=bdata,
device=args.device,
Expand All @@ -115,20 +119,24 @@ def main(args, return_data=False):
popt=args.popt,
replicate_col=args.replicate_col,
use_bcmatch=(not args.ignore_bcmatch),
negctrl_guide_idx=negctrl_idx,
)
if args.save_raw:
pkl.dump(bdata, open(f"{prefix}/ndata.pkl", "wb"))
guide_index = ndata.screen.guides.index.copy()
assert len(guide_index) == bdata.n_obs, (len(guide_index), bdata.n_obs)
if return_data:
return ndata
# Build variant dataframe
adj_negctrl_idx = None
_control_condition = args.control_condition.split(",")[0]
if args.library_design == "variant":
if not args.uniform_edit:
if "edit_rate" not in ndata.screen.guides.columns:
ndata.screen.get_edit_from_allele()
ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True)
ndata.screen.get_guide_edit_rate(
unsorted_condition_label=args.control_condition
unsorted_condition_label=_control_condition
)
target_info_df = _get_guide_target_info(
ndata.screen, args, cols_include=[args.negctrl_col]
Expand All @@ -139,6 +147,12 @@ def main(args, return_data=False):
== args.negctrl_col_value.lower()
)[0]
else:
if "edit_rate_norm" not in ndata.screen.guides.columns:
ndata.screen.get_edit_from_allele()
ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True)
ndata.screen.get_guide_edit_rate(
unsorted_condition_label=_control_condition
)
if args.splice_site_path is not None:
splice_site = pd.read_csv(args.splice_site_path).pos
target_info_df = be.an.translate_allele.annotate_edit(
Expand Down Expand Up @@ -198,7 +212,6 @@ def main(args, return_data=False):
],
axis=1,
)

# Add user-defined prior.
if args.prior_params is not None:
prior_params = _check_prior_params(args.prior_params, ndata)
Expand All @@ -210,13 +223,12 @@ def main(args, return_data=False):
with open(f"{prefix}/{model_label}.result.pkl", "rb") as handle:
param_history_dict = pkl.load(handle)
else:
param_history_dict, save_dict = deepcopy(
run_inference(model, guide, ndata, num_steps=args.n_iter)
)
save_dict = dict()
if args.fit_negctrl:
negctrl_model, negctrl_guide = identify_negctrl_model_guide(
args, "X_bcmatch" in bdata.layers
)
print(f"Using {negctrl_model} & {negctrl_guide}")
negctrl_idx = np.where(
ndata.screen.guides[args.negctrl_col].map(lambda s: s.lower())
== args.negctrl_col_value.lower()
Expand All @@ -225,16 +237,36 @@ def main(args, return_data=False):
f"Using {len(negctrl_idx)} negative control elements to adjust phenotypic effect sizes..."
)
ndata_negctrl = ndata[negctrl_idx]
print(
f"ndata size factor: {ndata.size_factor}, {ndata_negctrl.size_factor}\esf"
)
if args.save_raw:
pkl.dump(ndata_negctrl, open(f"{prefix}/ndata_negctrl.pkl", "wb"))
param_history_dict_negctrl, save_dict["negctrl"] = deepcopy(
run_inference(
negctrl_model, negctrl_guide, ndata_negctrl, num_steps=args.n_iter
)
)
if args.selection == "survival":
print(
f"Feeding mu_negctrl={param_history_dict_negctrl['mu_loc'], param_history_dict_negctrl['mu_scale']} into model..."
)
model = partial(
model,
mu_negctrl=(
param_history_dict_negctrl["mu_loc"].detach().mean(),
param_history_dict_negctrl["mu_scale"].detach().mean(),
),
)
else:
param_history_dict_negctrl = None
save_dict["data"] = ndata
param_history_dict, save_dict_model = deepcopy(
run_inference(model, guide, ndata, num_steps=args.n_iter)
)
for k, v in save_dict_model.items():
save_dict[k] = v
# Save results

outfile_path = (
f"{prefix}/bean_element[sgRNA]_result.{model_label}{args.result_suffix}.csv"
)
Expand Down Expand Up @@ -266,5 +298,6 @@ def main(args, return_data=False):
sample_covariates=(
ndata.sample_covariates if hasattr(ndata, "sample_covariates") else None
),
is_survival_screen=(args.selection == "survival"),
)
info("Done!")
24 changes: 15 additions & 9 deletions bean/framework/ReporterScreen.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,9 @@ def get_edit_mat_from_uns(
reporter_length = self.uns["reporter_length"]
if "reproter_right_flank_length" in self.uns:
reporter_right_flank_length = self.uns["reporter_right_flank_length"]
if edit_count_key not in self.uns or len(self.uns[edit_count_key]) == 0:
if edit_count_key not in self.uns:
raise ValueError(
"Edit count isn't calculated. "
f"Edit count isn't calculated or not provided with specified key `{edit_count_key}`. "
+ "Call .get_edit_from_allele(allele_count_key, allele_key)"
)
edits = self.uns[edit_count_key].copy()
Expand All @@ -410,7 +410,7 @@ def get_edit_mat_from_uns(
edits["guide_start_pos"] = (
reporter_length
- reporter_right_flank_length
- guide_len[edits.guide_idx].reset_index(drop=True)
- guide_len.iloc[edits.guide_idx].reset_index(drop=True)
)
if not match_target_position:
edits["rel_pos"] = edits.edit.map(lambda e: e.rel_pos)
Expand All @@ -434,12 +434,14 @@ def get_edit_mat_from_uns(
edits.target_pos_matches,
["guide", "edit"] + self.samples.index.tolist(),
]

good_edits = good_edits.copy()
good_guide_idx = guide_name_to_idx.loc[good_edits.guide, "index"].astype(int)
edit_mat = np.zeros(self.layers["edits"].shape)
for gidx, eidx in zip(good_guide_idx, good_edits.index):
self.layers["edits"][gidx, :] += good_edits.loc[
edit_mat[gidx, :] = edit_mat[gidx, :] + good_edits.loc[
eidx, self.samples.index.tolist()
].astype(int)
self.layers["edits"] = edit_mat
print("New edit matrix saved in .layers['edits']. Returning old edits.")
return old_edits

Expand Down Expand Up @@ -505,17 +507,21 @@ def get_guide_edit_rate(
prior_weight = 1
n_edits = self.layers[edit_layer].copy()[:, bulk_idx].sum(axis=1)
n_counts = self.layers[count_layer].copy()[:, bulk_idx].sum(axis=1)
edit_rate = (n_edits + prior_weight / 2) / (
(n_counts * num_targetable_sites) + prior_weight / 2
)
edit_rate = (n_edits + prior_weight / 2) / ((n_counts) + prior_weight / 2)
if normalize_by_editable_base:
edit_rate_norm = (n_edits + prior_weight / 2) / (
(n_counts * num_targetable_sites) + prior_weight / 2
)
edit_rate[n_counts < bcmatch_thres] = np.nan
if normalize_by_editable_base:
print("normalize by editable counts")
edit_rate[num_targetable_sites == 0] = np.nan
edit_rate_norm[num_targetable_sites == 0] = np.nan
if return_result:
return edit_rate
else:
self.guides["edit_rate"] = edit_rate
if normalize_by_editable_base:
self.guides["edit_rate_norm"] = edit_rate_norm
print(self.guides.edit_rate)

def get_edit_rate(
Expand Down
Loading
Loading