diff --git a/bean/cli/run.py b/bean/cli/run.py index 80520e9..0bf916d 100755 --- a/bean/cli/run.py +++ b/bean/cli/run.py @@ -129,14 +129,14 @@ def main(args, return_data=False): return ndata # Build variant dataframe adj_negctrl_idx = None - + _control_condition = args.control_condition.split(",")[0] if args.library_design == "variant": if not args.uniform_edit: if "edit_rate" not in ndata.screen.guides.columns: ndata.screen.get_edit_from_allele() ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True) ndata.screen.get_guide_edit_rate( - unsorted_condition_label=args.control_condition + unsorted_condition_label=_control_condition ) target_info_df = _get_guide_target_info( ndata.screen, args, cols_include=[args.negctrl_col] @@ -151,7 +151,7 @@ def main(args, return_data=False): ndata.screen.get_edit_from_allele() ndata.screen.get_edit_mat_from_uns(rel_pos_is_reporter=True) ndata.screen.get_guide_edit_rate( - unsorted_condition_label=args.control_condition + unsorted_condition_label=_control_condition ) if args.splice_site_path is not None: splice_site = pd.read_csv(args.splice_site_path).pos diff --git a/bean/mapping/GuideEditCounter.py b/bean/mapping/GuideEditCounter.py index 72dae7b..99acec2 100755 --- a/bean/mapping/GuideEditCounter.py +++ b/bean/mapping/GuideEditCounter.py @@ -11,6 +11,7 @@ from bean import Allele, ReporterScreen from Bio import SeqIO from Bio.SeqIO.QualityIO import FastqPhredIterator + if sys.stderr.isatty(): # Output into terminal from tqdm import tqdm @@ -19,6 +20,7 @@ def tqdm(iterable, **kwargs): return iterable + from ._supporting_fn import ( _base_edit_to_from, _get_edited_allele_crispresso, @@ -119,7 +121,11 @@ def __init__(self, **kwargs): ) self.screen.guides["guide_len"] = self.screen.guides.sequence.map(len) self.screen.uns["reporter_length"] = kwargs["reporter_length"] - self.screen.uns["reporter_right_flank_length"] = kwargs["reporter_length"] - kwargs["gstart_reporter"] - self.screen.guides["guide_len"].max() + self.screen.uns["reporter_right_flank_length"] = ( + kwargs["reporter_length"] + - kwargs["gstart_reporter"] + - self.screen.guides["guide_len"].max() + ) self.count_guide_edits = kwargs["count_guide_edits"] if self.count_guide_edits: self.screen.uns["guide_edit_counts"] = {} @@ -387,8 +393,8 @@ def _count_guide_edits( R1_record, len(ref_guide_seq) ) guide_edit_allele, score = _get_edited_allele_crispresso( - ref_seq=ref_guide_seq, - query_seq=read_guide_seq, + ref_seq=ref_guide_seq.upper(), + query_seq=read_guide_seq.upper(), target_base_edits=self.target_base_edits, aln_mat_path=self.output_dir + "/.aln_mat.txt", offset=0, @@ -506,8 +512,8 @@ def _count_reporter_edits( else: chrom = None allele, score = _get_edited_allele_crispresso( - ref_seq=ref_reporter_seq, - query_seq=read_reporter_seq, + ref_seq=ref_reporter_seq.upper(), + query_seq=read_reporter_seq.upper(), target_base_edits=self.target_base_edits, aln_mat_path=self.output_dir + "/.aln_mat.txt", offset=offset, @@ -547,7 +553,7 @@ def _get_guide_counts_bcmatch_semimatch( "duplicate_wo_barcode" ) outfile_R1_dup, outfile_R2_dup = self._get_fastq_handle("duplicate") - tqdm_reads= tqdm( + tqdm_reads = tqdm( enumerate(zip(R1_iter, R2_iter)), total=self.n_reads_after_filtering, postfix=f"n_read={self.bcmatch}", @@ -578,9 +584,7 @@ def _get_guide_counts_bcmatch_semimatch( matched_guide_idx = semimatch[0] self.screen.layers[semimatch_layer][matched_guide_idx, 0] += 1 if self.count_guide_edits: - guide_allele, _ = self._count_guide_edits( - matched_guide_idx, r1 - ) + guide_allele, _ = self._count_guide_edits(matched_guide_idx, r1) self.semimatch += 1 elif len(bc_match) >= 2: # duplicate mapping diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 1490cc9..64b22ac 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -138,11 +138,9 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True): replicate_plate2 = pyro.plate("rep_plate2", data.n_reps, dim=-2) time_plate = pyro.plate("time_plate", data.n_condits, dim=-2) guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) - # Set the prior for phenotype means - # print(f"why? {data.n_targets}, {data.target_lengths.shape}") - with pyro.plate("target_plate", data.n_targets): - mu_targets = pyro.sample("mu_targets", dist.Normal(0, 1)) - mu = torch.repeat_interleave(mu_targets, data.target_lengths) + + mu_targets = pyro.sample("mu_targets", dist.Normal(0, 1)) + mu = mu_targets.repeat(data.n_guides) with replicate_plate: with time_plate as t: time = data.timepoints[t] @@ -361,14 +359,6 @@ def MixtureNormalModel( mu.unsqueeze(0).expand((data.n_condits, -1, -1)) * time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)), ) - # negctrl_abundance = pyro.param( - # "negctrl_abundance", - # torch.ones((data.n_condits,)), - # constraint=constraints.positive, - # ) - # alleles_p_time = ( - # alleles_p_time / negctrl_abundance.clamp(min=1e-5)[:, None, None] - # ) assert alleles_p_time.shape == (data.n_condits, data.n_guides, 2) expected_allele_p = ( @@ -445,6 +435,7 @@ def MultiMixtureNormalModel( fit_noise: bool = False, prior_params: Optional[dict] = None, epsilon=1e-5, + mu_negctrl: float = (0.0, 0.1), ): """ Using the reporter outcome, phenotype of cells with a guide will be modeled as mixture of normal distributions of all major alleles (including WT) produced by the guide. @@ -467,7 +458,7 @@ def MultiMixtureNormalModel( guide_plate = pyro.plate("guide_plate", data.n_guides, dim=-1) mu_dist = dist.Laplace(0, 1) - initial_abundance = torch.ones(data.n_guides) / data.n_guides + # initial_abundance = torch.ones(data.n_guides) / data.n_guides if prior_params is not None: if "mu_loc" in prior_params or "mu_scale" in prior_params: mu_loc = 0.0 @@ -477,8 +468,8 @@ def MultiMixtureNormalModel( if "mu_scale" in prior_params: mu_scale = prior_params["mu_scale"] mu_dist = dist.Normal(mu_loc, mu_scale) - if "initial_abundance" in prior_params: - initial_abundance = prior_params["initial_abundance"] + # if "initial_abundance" in prior_params: + # initial_abundance = prior_params["initial_abundance"] # Set the prior for phenotype means with pyro.plate("guide_plate1", data.n_edits): @@ -489,17 +480,18 @@ def MultiMixtureNormalModel( data.n_max_alleles - 1, data.n_edits, ) + mu_targets = torch.matmul(data.allele_to_edit, mu_edits) assert mu_targets.shape == (data.n_guides, data.n_max_alleles - 1) - - mu = torch.cat([torch.zeros((data.n_guides, 1)), mu_targets], axis=-1) - r = torch.exp(mu) - - with pyro.plate("replicate_plate0", data.n_reps, dim=-1): - q_0 = pyro.sample( - "initial_guide_abundance", - dist.Dirichlet(initial_abundance.unsqueeze(0).expand(data.n_reps, -1)), + with pyro.plate("guide_plate_0", data.n_guides): + mu_guide_unedited = pyro.sample( + "mu_negctrl", dist.Normal(mu_negctrl[0], mu_negctrl[1]) ) + mu = torch.cat( + [mu_guide_unedited.unsqueeze(-1), mu_guide_unedited.unsqueeze(-1) + mu_targets], + axis=1, + ) + # The pi should be Dirichlet distributed instead of independent betas alpha_pi0 = ( torch.ones( @@ -536,11 +528,19 @@ def MultiMixtureNormalModel( pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1) ), ) - with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2): + with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2) as t: with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): - time_pi = data.control_timepoint + time_pi = data.control_timepoint[t] # If pi is sampled in later timepoint, account for the selection. - expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi + expanded_allele_p = pi * torch.exp( + mu.unsqueeze(0) + .unsqueeze(0) + .expand(data.n_reps, len(time_pi), -1, -1) + * time_pi.unsqueeze(0) + .unsqueeze(-1) + .unsqueeze(-1) + .expand(data.n_reps, -1, data.n_guides, data.n_max_alleles), + ) pyro.sample( "control_allele_count", dist.Multinomial(probs=expanded_allele_p, validate_args=False), @@ -558,11 +558,10 @@ def MultiMixtureNormalModel( assert time.shape == (data.n_condits,) with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): - alleles_p_time = torch.clamp( + alleles_p_time = torch.exp( time.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)) - * torch.log(r).unsqueeze(0).expand((data.n_condits, -1, -1)), - max=MAX_LOGPI, - ).exp() + * mu.unsqueeze(0).expand((data.n_condits, -1, -1)), + ) mask = data.allele_mask.unsqueeze(0).expand((data.n_condits, -1, -1)) alleles_p_time = alleles_p_time * mask @@ -575,7 +574,6 @@ def MultiMixtureNormalModel( expected_allele_p = ( pi.expand(data.n_reps, data.n_condits, -1, -1) * alleles_p_time[None, :, :, :] - * q_0.unsqueeze(1).unsqueeze(-1).expand((-1, data.n_condits, -1, -1)) ) expected_guide_p = expected_allele_p.sum(axis=-1) assert expected_guide_p.shape == ( @@ -583,65 +581,49 @@ def MultiMixtureNormalModel( data.n_condits, data.n_guides, ), expected_guide_p.shape - try: - with replicate_plate2: - with pyro.plate("guide_plate3", data.n_guides, dim=-1): - a = get_alpha( - expected_guide_p, data.size_factor, data.sample_mask, data.a0 + + with replicate_plate2: + with pyro.plate("guide_plate3", data.n_guides, dim=-1): + a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0) + a_bcmatch = get_alpha( + expected_guide_p, + data.size_factor_bcmatch, + data.sample_mask, + data.a0_bcmatch, + ) + # assert a.shape == a_bcmatch.shape == (data.n_reps, data.n_guides, data.n_condits) + assert ( + data.X.shape + == data.X_bcmatch_masked.shape + == ( + data.n_reps, + data.n_condits, + data.n_guides, ) - a_bcmatch = get_alpha( - expected_guide_p, - data.size_factor_bcmatch, - data.sample_mask, - data.a0_bcmatch, + ) + with poutine.mask( + mask=torch.logical_and( + data.X_masked.permute(0, 2, 1).sum(axis=-1) > 10, + data.repguide_mask, ) - # assert a.shape == a_bcmatch.shape == (data.n_reps, data.n_guides, data.n_condits) - assert ( - data.X.shape - == data.X_bcmatch_masked.shape - == ( - data.n_reps, - data.n_condits, - data.n_guides, - ) + ): + pyro.sample( + "guide_counts", + dist.DirichletMultinomial(a, validate_args=False), + obs=data.X_masked.permute(0, 2, 1), ) + if use_bcmatch: with poutine.mask( mask=torch.logical_and( - data.X_masked.permute(0, 2, 1).sum(axis=-1) > 10, + data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) > 10, data.repguide_mask, ) ): pyro.sample( - "guide_counts", - dist.DirichletMultinomial(a, validate_args=False), - obs=data.X_masked.permute(0, 2, 1), + "guide_bcmatch_counts", + dist.DirichletMultinomial(a_bcmatch, validate_args=False), + obs=data.X_bcmatch_masked.permute(0, 2, 1), ) - if use_bcmatch: - with poutine.mask( - mask=torch.logical_and( - data.X_bcmatch_masked.permute(0, 2, 1).sum(axis=-1) > 10, - data.repguide_mask, - ) - ): - pyro.sample( - "guide_bcmatch_counts", - dist.DirichletMultinomial(a_bcmatch, validate_args=False), - obs=data.X_bcmatch_masked.permute(0, 2, 1), - ) - except ValueError as e: - print(f"ERROR a is 0 at {torch.sum(a.sum(axis=-1) ==0)}") - print( - f"ERROR expected_guide_p is 0 at {torch.sum(expected_guide_p.sum(axis=1) ==0)}" - ) - print(f"ERROR a is NaN at {torch.where(a.isnan().any(axis=-1))}") - print( - f"ERROR data.size_factor is NaN at {torch.where(data.size_factor.isnan())}" - ) - print( - f"ERROR expected_guide_p is NaN at {torch.where(expected_guide_p.isnan().any(axis=1))}" - ) - print(f"ERROR a0 is NaN at {torch.where(data.a0.isnan())}") - raise e def NormalGuide(data): @@ -762,17 +744,16 @@ def ControlNormalGuide(data, mask_thres=10, use_bcmatch=True): Fit shared mean """ # Set the prior for phenotype means - mu_loc = pyro.param("mu_loc", torch.zeros((data.n_targets,))) + mu_loc = pyro.param("mu_loc", torch.tensor(0.0)) mu_scale = pyro.param( "mu_scale", - torch.ones((data.n_targets,)) * 0.1, + torch.tensor(1.0), constraint=constraints.positive, ) - with pyro.plate("target_plate", data.n_targets): - mu = pyro.sample( - "mu_targets", - dist.Normal(mu_loc, mu_scale), - ) + mu = pyro.sample( + "mu_targets", + dist.Normal(mu_loc, mu_scale), + ) def MultiMixtureNormalGuide( @@ -794,11 +775,6 @@ def MultiMixtureNormalGuide( torch.ones(data.n_guides) / data.n_guides, constraint=constraints.positive, ) - with pyro.plate("replicate_plate0", data.n_reps, dim=-1): - q_0 = pyro.sample( - "initial_guide_abundance", - dist.Dirichlet(initial_abundance), - ) # Set the prior for phenotype means mu_loc = pyro.param("mu_loc", torch.zeros((data.n_edits,))) mu_scale = pyro.param(