Skip to content

Commit

Permalink
debug multi-gene symbol input
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Oct 13, 2023
1 parent 196a386 commit 3205411
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 32 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ bean-filter my_sorting_screen.h5ad \
-o my_sorting_screen_masked.h5ad \
--translate `# Translate coding variants` \
[ --translate-gene-name GENE_SYMBOL OR
--translate-gene-names path_to_gene_names_file.txt OR
--translate-genes-list path_to_gene_names_file.txt OR
--translate-fasta gene_exon.fa, OR
--translate-fastas-csv gene_exon_fas.csv]
```
Expand All @@ -211,7 +211,9 @@ bean-filter my_sorting_screen.h5ad \
### Output
Above command produces
* `my_sorting_screen_filtered.h5ad` with filtered alleles stored in `.uns`,
* `my_sorting_screen_filtered.filtered_allele_stats.pdf`, and `my_sorting_screen_filtered.filter_log.txt` that report allele count stats in each filtering step.
* `my_sorting_screen_filtered.filtered_allele_stats.pdf`, and `my_sorting_screen_filtered.filter_log.txt` that report allele count stats in each filtering step.

You may want to adjust the flitering parameters to obtain optimal balance between # guides per variant & # variants that are scored. See example outputs of filtering step [here](docs/example_filtering_outputs/).


### Additional parameters
Expand All @@ -227,6 +229,8 @@ Above command produces
* `--translate` (default: `False`): Translate nucleotide-level variants prior to allele proportion filtering.
* `-f`, `--translate-fasta` (defulat: `None`): fasta file path with exon positions. If not provided and `--translate` flag is provided, LDLR hg19 coordinates will be used.
* `-fs`, `--translate-fastas-csv` (defulat: `None`): .csv with two columns with gene IDs and FASTA file path corresponding to each gene.
* `-g`, `--translate-gene-name` (default: `None`): Gene symbol for translation
* `-gs`, `--translate-genes-list` (default: `None`): Path to the text file with gene symbols in each line
* `-ap`, `--filter-allele-proportion` (default: `0.05`): If provided, only the alleles that exceed `filter_allele_proportion` in `filter-sample-proportion` will be retained.
* `-ac`, `--filter-allele-count` (default: `5`): If provided, alleles that exceed `filter_allele_proportion` AND `filter_allele_count` in `filter-sample-proportion` will be retained.
* `sp`, `--filter-sample-proportion` (default: `0.2`): "If `filter_allele_proportion` is provided, alleles that exceed `filter_allele_proportion` in `filter-sample-proportion` will be retained.
Expand Down
23 changes: 14 additions & 9 deletions bean/annotate/translate_allele.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,17 +304,18 @@ class CDSCollection:

def __init__(
self,
gene_ids: List[str] = None,
gene_names: List[str] = None,
fasta_file_names: List[str] = None,
suppressMessage=True,
):
if fasta_file_names is None:
self.cds_dict = get_cds_dict(gene_ids)
elif len(gene_ids) != len(fasta_file_names):
raise ValueError("gene_ids and fasta_file_names have different lengths")
self.cds_dict = {}
for gid, fasta_file in zip(fasta_file_names, gene_ids):
self.cds_dict[gid] = CDS(fasta_file)
self.cds_dict = get_cds_dict(gene_names)
elif len(gene_names) != len(fasta_file_names):
raise ValueError("gene_names and fasta_file_names have different lengths")
else:
self.cds_dict = {}
for gid, fasta_file in zip(fasta_file_names, gene_names):
self.cds_dict[gid] = CDS(fasta_file)
self.cds_ranges = self.get_cds_ranges()

def get_cds_ranges(self):
Expand All @@ -325,8 +326,12 @@ def get_cds_ranges(self):
for gene_id, cds in self.cds_dict.items():
gids.append(gene_id)
seqnames.append(cds.chrom)
starts.append(cds.start)
ends.append(cds.end)
starts.append(cds.genomic_pos[0])
try:
ends.append(cds.genomic_pos[-1])
except:
print(cds.genomic_pos)
exit(1)

return pd.DataFrame(
{
Expand Down
23 changes: 17 additions & 6 deletions bean/annotate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,18 @@ def get_mane_transcript_id(gene_name: str):
response = requests.get(api_url, headers={"Content-Type": "application/json"})
mane_json = response.json()
mane_df = pd.DataFrame.from_records(mane_json)
mane_transcript_id = mane_df.loc[
mane_df.ens_gene_name == gene_name, "ens_stable_id"
].values[0]
id_version = mane_df.loc[
mane_df.ens_gene_name == gene_name, "ens_stable_id_version"
].values[0]
try:
mane_transcript_id = mane_df.loc[
mane_df.ens_gene_name == gene_name, "ens_stable_id"
].values[0]
id_version = mane_df.loc[
mane_df.ens_gene_name == gene_name, "ens_stable_id_version"
].values[0]
except IndexError as e:
print(
f"Cannot find {gene_name} from MANE database: check http://tark.ensembl.org/api/transcript/manelist/ or use custom fasta."
)
exit(1)
return mane_transcript_id, id_version


Expand Down Expand Up @@ -313,6 +319,11 @@ def check_args(args):
raise ValueError(
"Invalid arguments: You should specify exactly one of --translate-fasta, --translate-fastas-csv, --translate-gene, translate-genes-list to translate alleles."
)
if args.translate_genes_list is not None:
args.translate_genes_list = (
pd.read_csv(args.translate_genes_list, header=None).values[:, 0].tolist()
)
info(f"Using {args.translate_genes_list} as genes for translation.")
if args.translate_fastas_csv:
tbl = pd.read_csv(
args.translate_fastas_csv,
Expand Down
13 changes: 12 additions & 1 deletion bean/framework/Edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
self.ref_base = ref_base # TODO make it ref / alt instead of ref_base and alt_base for AAEdit comp. or make abstract class
self.alt_base = alt_base
self.uid = unique_identifier
if type(strand) == int:
if isinstance(strand, int):
strand_to_symbol = {1: "+", -1: "-"}
self.strand = strand_to_symbol[strand]
else:
Expand Down Expand Up @@ -92,6 +92,10 @@ def set_uid(self, uid):
self.uid = uid
return self

def set_chrom(self, chrom):
self.chrom = chrom
return self

def get_abs_base_change(self):
if self.strand == "-":
ref_base = type(self).reverse_map[self.ref_base]
Expand Down Expand Up @@ -135,6 +139,8 @@ def __init__(self, edits: Iterable[Edit] = None):
self.edits = set() if edits is None else set(edits)
if edits and len(edits) > 0:
self.chrom = next(iter(edits)).chrom
else:
self.chrom = None

@classmethod
def from_str(cls, allele_str): # pos:strand:start>end
Expand Down Expand Up @@ -212,11 +218,16 @@ def has_other_edit(self, ref_base, alt_base, pos=None, rel_pos=None):
return False

def get_jaccard(self, other):
if self.chrom != other.chrom:
return 0
return jaccard(set(map(str, self.edits)), set(map(str, other.edits)))

def get_jaccards(self, allele_list: Iterable[Allele]):
return np.array(list(map(lambda o: self.get_jaccard(o), allele_list)))

def set_chrom(self, chrom: str):
self.edits = {edit.set_chrom(chrom) for edit in self.edits}

def map_to_closest(
self, allele_list, jaccard_threshold=0.5, merge_priority: np.ndarray = None
):
Expand Down
4 changes: 2 additions & 2 deletions bean/mapping/GuideEditCounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def __init__(self, **kwargs):
self.align_score_threshold = 80
self.target_pos_col = kwargs["target_pos_col"]

self.guide_start_seq = kwargs["guide_start_seq"]
self.guide_end_seq = kwargs["guide_end_seq"]
self.guide_start_seq = kwargs["guide_start_seq"].upper()
self.guide_end_seq = kwargs["guide_end_seq"].upper()
if not self.guide_start_seq == "":
info(
f"{self.name}: Using guide_start_seq={self.guide_start_seq} for {self.output_dir}"
Expand Down
26 changes: 22 additions & 4 deletions bean/plotting/allele_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def plot_n_alleles_per_guide(
for g in bdata.guides.index
]
ax.hist(lens, bins=np.arange(min(lens) - 0.5, max(lens) + 0.5))
ax.set_xlabel("# alleles")
ax.set_title(f"# Alleles per guide, raw (n={len(bdata.uns[allele_df_key])})")
ax.set_xlabel("# alleles per guide")
ax.set_title(f"n_alleles={len(bdata.uns[allele_df_key])}")
ax.set_ylabel("# guides")
return ax

Expand All @@ -40,7 +40,25 @@ def plot_n_guides_per_edit(
if ax is None:
fig, ax = plt.subplots()
ax.hist(n_guides, bins=np.arange(min(n_guides) - 0.5, max(n_guides) + 0.5))
ax.set_xlabel("# guides")
ax.set_title(f"# guides per edit (n={len(edits_df)})")
ax.set_xlabel("# guides per variant")
ax.set_title(f"n_variants={len(edits_df)}")
ax.set_ylabel("# edits")
return ax


def plot_allele_stats(bdata, allele_df_keys, plot_save_path):
fig = plt.figure(constrained_layout=True, figsize=(6, 3 * len(allele_df_keys)))
fig.suptitle("Allele stats")

subfigs = fig.subfigures(nrows=len(allele_df_keys), ncols=1)
for row, subfig in enumerate(subfigs):
key = allele_df_keys[row]
subfig.suptitle(f"{key}")

# create 1x3 subplots per subfig
ax = subfig.subplots(nrows=1, ncols=2)
plot_n_alleles_per_guide(bdata, key, bdata.uns[key].columns[1], ax[0])
plot_n_guides_per_edit(bdata, key, bdata.uns[key].columns[1], ax[1])

#plt.tight_layout()
fig.savefig(plot_save_path, bbox_inches="tight")
16 changes: 8 additions & 8 deletions bin/bean-filter
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ import sys
import logging
import pandas as pd
import bean as be
from bean.plotting.allele_stats import plot_n_alleles_per_guide, plot_n_guides_per_edit
from bean.plotting.allele_stats import (
plot_n_alleles_per_guide,
plot_n_guides_per_edit,
plot_allele_stats,
)
from bean.annotate.utils import parse_args, check_args
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -147,13 +151,9 @@ if __name__ == "__main__":
bdata.write(f"{args.output_prefix}.h5ad")

info("Plotting allele stats for each filtering step...")
fig, ax = plt.subplots(len(allele_df_keys), 2, figsize=(6, 3 * len(allele_df_keys)))
for i, key in enumerate(allele_df_keys):
if len(bdata.uns[key]) > 0:
plot_n_alleles_per_guide(bdata, key, bdata.uns[key].columns[1], ax[i, 0])
plot_n_guides_per_edit(bdata, key, bdata.uns[key].columns[1], ax[i, 1])
plt.tight_layout()
plt.savefig(f"{args.output_prefix}.filtered_allele_stats.pdf", bbox_inches="tight")
plot_allele_stats(
bdata, allele_df_keys, f"{args.output_prefix}.filtered_allele_stats.pdf"
)
info(
f"Saving plotting result and log at {args.output_prefix}.[filtered_allele_stats.pdf, filter_log.txt]."
)
Expand Down

0 comments on commit 3205411

Please sign in to comment.