From d538db92dc09daf78f23c2b36f2329c7dd287df9 Mon Sep 17 00:00:00 2001 From: jykr Date: Tue, 9 Apr 2024 11:12:57 -0400 Subject: [PATCH] allow pi to be sampled from selected samples --- bean/cli/run.py | 13 +++++++------ bean/model/parser.py | 8 ++++---- bean/model/readwrite.py | 4 ++-- bean/model/survival_model.py | 12 ++++++++---- bean/preprocessing/data_class.py | 10 ++++++++++ 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/bean/cli/run.py b/bean/cli/run.py index 3d040da..705e275 100755 --- a/bean/cli/run.py +++ b/bean/cli/run.py @@ -90,6 +90,8 @@ def main(args): model_label, model, guide = identify_model_guide(args) info("Done loading data. Preprocessing...") + + # Format bdata into data structure compatible with Pyro model bdata = prepare_bdata(bdata, args, warn, prefix) guide_index = bdata.guides.index.copy() ndata = DATACLASS_DICT[args.selection][model_label]( @@ -103,7 +105,7 @@ def main(args): condition_column=args.condition_col, time_column=args.time_col, control_condition=args.control_condition, - control_can_be_selected=args.include_control_condition_for_inference, + control_can_be_selected=not args.exclude_control_condition_for_inference, allele_df_key=args.allele_df_key, control_guide_tag=args.control_guide_tag, target_col=args.target_col, @@ -112,6 +114,8 @@ def main(args): replicate_col=args.replicate_col, use_bcmatch=(not args.ignore_bcmatch), ) + + # Build variant dataframe adj_negctrl_idx = None if args.library_design == "variant": if not args.uniform_edit: @@ -151,8 +155,8 @@ def main(args): guide_info_df = ndata.screen.guides + # Run the inference steps info(f"Running inference for {model_label}...") - if args.load_existing: with open(f"{prefix}/{model_label}.result.pkl", "rb") as handle: param_history_dict = pkl.load(handle) @@ -187,11 +191,8 @@ def main(args): if not os.path.exists(prefix): os.makedirs(prefix) with open(f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb") as handle: - # try: pkl.dump(save_dict, handle) 
- # except TypeError as exc: - # print(exc.message) - # print(param_history_dict) + write_result_table( target_info_df, param_history_dict, diff --git a/bean/model/parser.py b/bean/model/parser.py index db79c14..91b53cc 100755 --- a/bean/model/parser.py +++ b/bean/model/parser.py @@ -83,14 +83,14 @@ def parse_args(parser=None): "--control-condition", default="bulk", type=str, - help="Value in `bdata.samples[condition_col]` that indicates control experimental condition.", + help="Value in `bdata.samples[condition_col]` that indicates control experimental condition whose editing patterns will be used.", ) parser.add_argument( - "--include-control-condition-for-inference", - "-ic", + "--exclude-control-condition-for-inference", + "-ec", default=False, action="store_true", - help="Include control conditions for inference. Currently only supported for survival screens.", + help="Exclude control conditions for inference. Currently only supported for survival screens.", ) parser.add_argument( "--replicate-col", diff --git a/bean/model/readwrite.py b/bean/model/readwrite.py index c7cf181..19f3f9a 100755 --- a/bean/model/readwrite.py +++ b/bean/model/readwrite.py @@ -55,11 +55,11 @@ def write_result_table( negctrl_params=None, write_fitted_eff: bool = True, adjust_confidence_by_negative_control: bool = True, - adjust_confidence_negatives: np.ndarray = None, + adjust_confidence_negatives: Optional[np.ndarray] = None, guide_index: Optional[Sequence[str]] = None, guide_acc: Optional[Sequence] = None, sd_is_fitted: bool = True, - sample_covariates: List[str] = None, + sample_covariates: Optional[List[str]] = None, return_result: bool = False, ) -> Union[pd.DataFrame, None]: """Combine target information and scores to write result table to a csv file or return it.""" diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 88d7d10..d2c3d7b 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -229,7 +229,6 @@ def 
MixtureNormalModel( "initial_guide_abundance", dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))), ) - # The pi should be Dirichlet distributed instead of independent betas alpha_pi = pyro.param( "alpha_pi", torch.ones( @@ -248,6 +247,7 @@ def MixtureNormalModel( ), alpha_pi.shape with replicate_plate: with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): + time_pi = data.control_timepoint # Accounting for sample specific overall edit rate across all guides. # P(allele | guide, bin=bulk) pi = pyro.sample( @@ -262,9 +262,11 @@ def MixtureNormalModel( data.n_guides, 2, ), pi.shape + # If pi is sampled in later timepoint, account for the selection. + expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi pyro.sample( "bulk_allele_count", - dist.Multinomial(probs=pi, validate_args=False), + dist.Multinomial(probs=expanded_allele_p, validate_args=False), obs=data.allele_counts_control, ) if scale_by_accessibility: @@ -277,7 +279,6 @@ def MixtureNormalModel( time = data.timepoints[t] assert time.shape == (data.n_condits,) - # with guide_plate, poutine.mask(mask=(data.allele_counts.sum(axis=-1) == 0)): with guide_plate: alleles_p_time = r.unsqueeze(0).expand( (data.n_condits, -1, -1) @@ -502,15 +503,18 @@ def MultiMixtureNormalModel( print(torch.where(alpha_pi < 0)) with replicate_plate: with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)): + time_pi = data.control_timepoint pi = pyro.sample( "pi", dist.Dirichlet( pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1) ), ) + # If pi is sampled in later timepoint, account for the selection. 
+ expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi pyro.sample( "bulk_allele_count", - dist.Multinomial(probs=pi, validate_args=False), + dist.Multinomial(probs=expanded_allele_p, validate_args=False), obs=data.allele_counts_control, ) if scale_by_accessibility: diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 0c37fb2..5c1f430 100755 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -999,6 +999,16 @@ def _post_init( self.timepoints = torch.as_tensor( self.screen_selected.samples[self.time_column].unique() ) + control_timepoint = self.screen_control.samples[self.time_column].unique() + if len(control_timepoint) != 1: + info(self.screen_control) + info(self.screen_control.samples) + info(control_timepoint) + raise ValueError( + "All samples with --control-condition should have the same --time-col column in ReporterScreen.samples[time_col]. Check your input ReporterScreen object." + ) + else: + self.control_timepoint = control_timepoint[0] self.n_timepoints = self.n_condits timepoints = self.screen_selected.samples.sort_values(self.time_column)[ self.time_column