Skip to content

Commit

Permalink
tiling screen update
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Oct 2, 2024
1 parent acac2ad commit 8125c26
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 105 deletions.
6 changes: 3 additions & 3 deletions bean/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ def main(args, return_data=False):
return ndata
# Build variant dataframe
adj_negctrl_idx = None

_control_condition = args.control_condition.split(",")[0]
if args.library_design == "variant":
if not args.uniform_edit:
if "edit_rate" not in ndata.screen.guides.columns:
ndata.screen.get_edit_from_allele()
ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True)
ndata.screen.get_guide_edit_rate(
unsorted_condition_label=args.control_condition
unsorted_condition_label=_control_condition
)
target_info_df = _get_guide_target_info(
ndata.screen, args, cols_include=[args.negctrl_col]
Expand All @@ -151,7 +151,7 @@ def main(args, return_data=False):
ndata.screen.get_edit_from_allele()
ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True)
ndata.screen.get_guide_edit_rate(
unsorted_condition_label=args.control_condition
unsorted_condition_label=_control_condition
)
if args.splice_site_path is not None:
splice_site = pd.read_csv(args.splice_site_path).pos
Expand Down
22 changes: 13 additions & 9 deletions bean/mapping/GuideEditCounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from bean import Allele, ReporterScreen
from Bio import SeqIO
from Bio.SeqIO.QualityIO import FastqPhredIterator

if sys.stderr.isatty():
# Output into terminal
from tqdm import tqdm
Expand All @@ -19,6 +20,7 @@
def tqdm(iterable, **kwargs):
    """Return *iterable* unchanged, standing in for ``tqdm.tqdm``.

    This definition shadows the name ``tqdm`` so that call sites written
    against the real progress bar (e.g. ``tqdm(reads, total=..., postfix=...)``)
    keep working without drawing anything.
    NOTE(review): presumably this lives in the ``else`` branch of the
    ``sys.stderr.isatty()`` check, i.e. it is used when output is not a
    terminal -- confirm against the full file.

    Args:
        iterable: Any iterable; it is handed back as-is (same object).
        **kwargs: Accepted for call compatibility with ``tqdm.tqdm`` and
            ignored.

    Returns:
        The very same ``iterable`` object that was passed in.
    """
    # Progress-bar options are intentionally discarded; nothing is wrapped.
    return iterable


from ._supporting_fn import (
_base_edit_to_from,
_get_edited_allele_crispresso,
Expand Down Expand Up @@ -119,7 +121,11 @@ def __init__(self, **kwargs):
)
self.screen.guides["guide_len"] = self.screen.guides.sequence.map(len)
self.screen.uns["reporter_length"] = kwargs["reporter_length"]
self.screen.uns["reporter_right_flank_length"] = kwargs["reporter_length"] - kwargs["gstart_reporter"] - self.screen.guides["guide_len"].max()
self.screen.uns["reporter_right_flank_length"] = (
kwargs["reporter_length"]
- kwargs["gstart_reporter"]
- self.screen.guides["guide_len"].max()
)
self.count_guide_edits = kwargs["count_guide_edits"]
if self.count_guide_edits:
self.screen.uns["guide_edit_counts"] = {}
Expand Down Expand Up @@ -387,8 +393,8 @@ def _count_guide_edits(
R1_record, len(ref_guide_seq)
)
guide_edit_allele, score = _get_edited_allele_crispresso(
ref_seq=ref_guide_seq,
query_seq=read_guide_seq,
ref_seq=ref_guide_seq.upper(),
query_seq=read_guide_seq.upper(),
target_base_edits=self.target_base_edits,
aln_mat_path=self.output_dir + "/.aln_mat.txt",
offset=0,
Expand Down Expand Up @@ -506,8 +512,8 @@ def _count_reporter_edits(
else:
chrom = None
allele, score = _get_edited_allele_crispresso(
ref_seq=ref_reporter_seq,
query_seq=read_reporter_seq,
ref_seq=ref_reporter_seq.upper(),
query_seq=read_reporter_seq.upper(),
target_base_edits=self.target_base_edits,
aln_mat_path=self.output_dir + "/.aln_mat.txt",
offset=offset,
Expand Down Expand Up @@ -547,7 +553,7 @@ def _get_guide_counts_bcmatch_semimatch(
"duplicate_wo_barcode"
)
outfile_R1_dup, outfile_R2_dup = self._get_fastq_handle("duplicate")
tqdm_reads= tqdm(
tqdm_reads = tqdm(
enumerate(zip(R1_iter, R2_iter)),
total=self.n_reads_after_filtering,
postfix=f"n_read={self.bcmatch}",
Expand Down Expand Up @@ -578,9 +584,7 @@ def _get_guide_counts_bcmatch_semimatch(
matched_guide_idx = semimatch[0]
self.screen.layers[semimatch_layer][matched_guide_idx, 0] += 1
if self.count_guide_edits:
guide_allele, _ = self._count_guide_edits(
matched_guide_idx, r1
)
guide_allele, _ = self._count_guide_edits(matched_guide_idx, r1)
self.semimatch += 1

elif len(bc_match) >= 2: # duplicate mapping
Expand Down
162 changes: 69 additions & 93 deletions bean/model/survival_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,9 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True):
replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2)
time_plate = pyro.plate("time_plate", data.n_condits, dim=-2)
guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1)
# Set the prior for phenotype means
# print(f"why? {data.n_targets}, {data.target_lengths.shape}")
with pyro.plate("target_plate", data.n_targets):
mu_targets = pyro.sample("mu_targets", dist.Normal(0, 1))
mu = torch.repeat_interleave(mu_targets, data.target_lengths)

mu_targets = pyro.sample("mu_targets", dist.Normal(0, 1))
mu = mu_targets.repeat(data.n_guides)
with replicate_plate:
with time_plate as t:
time = data.timepoints[t]
Expand Down Expand Up @@ -361,14 +359,6 @@ def MixtureNormalModel(
mu.unsqueeze(0).expand((data.n_condits, -1, -1))
* time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)),
)
# negctrl_abundance = pyro.param(
# "negctrl_abundance",
# torch.ones((data.n_condits,)),
# constraint=constraints.positive,
# )
# alleles_p_time = (
# alleles_p_time / negctrl_abundance.clamp(min=1e-5)[:, None, None]
# )
assert alleles_p_time.shape == (data.n_condits, data.n_guides, 2)

expected_allele_p = (
Expand Down Expand Up @@ -445,6 +435,7 @@ def MultiMixtureNormalModel(
fit_noise: bool = False,
prior_params: Optional[dict] = None,
epsilon=1e-5,
mu_negctrl: float = (0.0, 0.1),
):
"""
Using the reporter outcome, phenotype of cells with a guide will be modeled as mixture of normal distributions of all major alleles (including WT) produced by the guide.
Expand All @@ -467,7 +458,7 @@ def MultiMixtureNormalModel(
guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1)

mu_dist = dist.Laplace(0, 1)
initial_abundance = torch.ones(data.n_guides) / data.n_guides
# initial_abundance = torch.ones(data.n_guides) / data.n_guides
if prior_params is not None:
if "mu_loc" in prior_params or "mu_scale" in prior_params:
mu_loc = 0.0
Expand All @@ -477,8 +468,8 @@ def MultiMixtureNormalModel(
if "mu_scale" in prior_params:
mu_scale = prior_params["mu_scale"]
mu_dist = dist.Normal(mu_loc, mu_scale)
if "initial_abundance" in prior_params:
initial_abundance = prior_params["initial_abundance"]
# if "initial_abundance" in prior_params:
# initial_abundance = prior_params["initial_abundance"]

# Set the prior for phenotype means
with pyro.plate("guide_plate1", data.n_edits):
Expand All @@ -489,17 +480,18 @@ def MultiMixtureNormalModel(
data.n_max_alleles - 1,
data.n_edits,
)

mu_targets = torch.matmul(data.allele_to_edit, mu_edits)
assert mu_targets.shape == (data.n_guides, data.n_max_alleles - 1)

mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_targets], axis=-1)
r = torch.exp(mu)

with pyro.plate("replicate_plate0", data.n_reps, dim=-1):
q_0 = pyro.sample(
"initial_guide_abundance",
dist.Dirichlet(initial_abundance.unsqueeze(0).expand(data.n_reps, -1)),
with pyro.plate("guide_plate_0", data.n_guides):
mu_guide_unedited = pyro.sample(
"mu_negctrl", dist.Normal(mu_negctrl[0], mu_negctrl[1])
)
mu = torch.cat(
[mu_guide_unedited.unsqueeze(-1), mu_guide_unedited.unsqueeze(-1) + mu_targets],
axis=1,
)

# The pi should be Dirichlet distributed instead of independent betas
alpha_pi0 = (
torch.ones(
Expand Down Expand Up @@ -536,11 +528,19 @@ def MultiMixtureNormalModel(
pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1)
),
)
with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2):
with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2) as t:
with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
time_pi = data.control_timepoint
time_pi = data.control_timepoint[t]
# If pi is sampled in later timepoint, account for the selection.
expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi
expanded_allele_p = pi * torch.exp(
mu.unsqueeze(0)
.unsqueeze(0)
.expand(data.n_reps, len(time_pi), -1, -1)
* time_pi.unsqueeze(0)
.unsqueeze(-1)
.unsqueeze(-1)
.expand(data.n_reps, -1, data.n_guides, data.n_max_alleles),
)
pyro.sample(
"control_allele_count",
dist.Multinomial(probs=expanded_allele_p, validate_args=False),
Expand All @@ -558,11 +558,10 @@ def MultiMixtureNormalModel(
assert time.shape == (data.n_condits,)

with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
alleles_p_time = torch.clamp(
alleles_p_time = torch.exp(
time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1))
* torch.log(r).unsqueeze(0).expand((data.n_condits, -1, -1)),
max=MAX_LOGPI,
).exp()
* mu.unsqueeze(0).expand((data.n_condits, -1, -1)),
)

mask = data.allele_mask.unsqueeze(0).expand((data.n_condits, -1, -1))
alleles_p_time = alleles_p_time * mask
Expand All @@ -575,73 +574,56 @@ def MultiMixtureNormalModel(
expected_allele_p = (
pi.expand(data.n_reps, data.n_condits, -1, -1)
* alleles_p_time[None, :, :, :]
* q_0.unsqueeze(1).unsqueeze(-1).expand((-1, data.n_condits, -1, -1))
)
expected_guide_p = expected_allele_p.sum(axis=-1)
assert expected_guide_p.shape == (
data.n_reps,
data.n_condits,
data.n_guides,
), expected_guide_p.shape
try:
with replicate_plate2:
with pyro.plate("guide_plate3", data.n_guides, dim=-1):
a = get_alpha(
expected_guide_p, data.size_factor, data.sample_mask, data.a0

with replicate_plate2:
with pyro.plate("guide_plate3", data.n_guides, dim=-1):
a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0)
a_bcmatch = get_alpha(
expected_guide_p,
data.size_factor_bcmatch,
data.sample_mask,
data.a0_bcmatch,
)
# assert a.shape == a_bcmatch.shape == (data.n_reps, data.n_guides, data.n_condits)
assert (
data.X.shape
== data.X_bcmatch_masked.shape
== (
data.n_reps,
data.n_condits,
data.n_guides,
)
a_bcmatch = get_alpha(
expected_guide_p,
data.size_factor_bcmatch,
data.sample_mask,
data.a0_bcmatch,
)
with poutine.mask(
mask=torch.logical_and(
data.X_masked.permute(0, 2, 1).sum(axis=-1) > 10,
data.repguide_mask,
)
# assert a.shape == a_bcmatch.shape == (data.n_reps, data.n_guides, data.n_condits)
assert (
data.X.shape
== data.X_bcmatch_masked.shape
== (
data.n_reps,
data.n_condits,
data.n_guides,
)
):
pyro.sample(
"guide_counts",
dist.DirichletMultinomial(a, validate_args=False),
obs=data.X_masked.permute(0, 2, 1),
)
if use_bcmatch:
with poutine.mask(
mask=torch.logical_and(
data.X_masked.permute(0, 2, 1).sum(axis=-1) > 10,
data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) > 10,
data.repguide_mask,
)
):
pyro.sample(
"guide_counts",
dist.DirichletMultinomial(a, validate_args=False),
obs=data.X_masked.permute(0, 2, 1),
"guide_bcmatch_counts",
dist.DirichletMultinomial(a_bcmatch, validate_args=False),
obs=data.X_bcmatch_masked.permute(0, 2, 1),
)
if use_bcmatch:
with poutine.mask(
mask=torch.logical_and(
data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) > 10,
data.repguide_mask,
)
):
pyro.sample(
"guide_bcmatch_counts",
dist.DirichletMultinomial(a_bcmatch, validate_args=False),
obs=data.X_bcmatch_masked.permute(0, 2, 1),
)
except ValueError as e:
print(f"ERROR a is 0 at {torch.sum(a.sum(axis=-1) ==0)}")
print(
f"ERROR expected_guide_p is 0 at {torch.sum(expected_guide_p.sum(axis=1) ==0)}"
)
print(f"ERROR a is NaN at {torch.where(a.isnan().any(axis=-1))}")
print(
f"ERROR data.size_factor is NaN at {torch.where(data.size_factor.isnan())}"
)
print(
f"ERROR expected_guide_p is NaN at {torch.where(expected_guide_p.isnan().any(axis=1))}"
)
print(f"ERROR a0 is NaN at {torch.where(data.a0.isnan())}")
raise e


def NormalGuide(data):
Expand Down Expand Up @@ -762,17 +744,16 @@ def ControlNormalGuide(data, mask_thres=10, use_bcmatch=True):
Fit shared mean
"""
# Set the prior for phenotype means
mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets,)))
mu_loc = pyro.param("mu_loc", torch.tensor(0.0))
mu_scale = pyro.param(
"mu_scale",
torch.ones((data.n_targets,)) * 0.1,
torch.tensor(1.0),
constraint=constraints.positive,
)
with pyro.plate("target_plate", data.n_targets):
mu = pyro.sample(
"mu_targets",
dist.Normal(mu_loc, mu_scale),
)
mu = pyro.sample(
"mu_targets",
dist.Normal(mu_loc, mu_scale),
)


def MultiMixtureNormalGuide(
Expand All @@ -794,11 +775,6 @@ def MultiMixtureNormalGuide(
torch.ones(data.n_guides) / data.n_guides,
constraint=constraints.positive,
)
with pyro.plate("replicate_plate0", data.n_reps, dim=-1):
q_0 = pyro.sample(
"initial_guide_abundance",
dist.Dirichlet(initial_abundance),
)
# Set the prior for phenotype means
mu_loc = pyro.param("mu_loc", torch.zeros((data.n_edits,)))
mu_scale = pyro.param(
Expand Down

0 comments on commit 8125c26

Please sign in to comment.