diff --git a/scdna_replication_tools/pert_model.py b/scdna_replication_tools/pert_model.py index 05d9bc7..e3264bf 100644 --- a/scdna_replication_tools/pert_model.py +++ b/scdna_replication_tools/pert_model.py @@ -187,8 +187,18 @@ def process_input_data(self): assert cn_s_reads.shape[0] == gammas.shape[0] == rt_prior_profile.shape[0] else: rt_prior_profile = None + + # get tensor for rt initialization if provided + if (self.rt_init_col is not None) and (self.rt_init_col in self.cn_s.columns): + rt_init_profile = self.cn_s[[self.chr_col, self.start_col, self.rt_init_col]].drop_duplicates() + rt_init_profile = rt_init_profile.dropna() + rt_init_profile = torch.tensor(rt_init_profile[self.rt_init_col].values).unsqueeze(-1).to(torch.float32) + rt_init_profile = self.convert_rt_prior_units(rt_init_profile) + assert cn_s_reads.shape[0] == gammas.shape[0] == rt_init_profile.shape[0] + else: + rt_init_profile = None - return cn_g1_reads_df, cn_g1_states_df, cn_s_reads_df, cn_s_states_df, cn_g1_reads, cn_g1_states, cn_s_reads, cn_s_states, gammas, rt_prior_profile, libs_g1, libs_s + return cn_g1_reads_df, cn_g1_states_df, cn_s_reads_df, cn_s_states_df, cn_g1_reads, cn_g1_states, cn_s_reads, cn_s_states, gammas, rt_prior_profile, rt_init_profile, libs_g1, libs_s def sort_by_cell_and_loci(self, cn): @@ -655,7 +665,7 @@ def run_pert_model(self): cn_g1_reads_df, cn_g1_states_df, cn_s_reads_df, cn_s_states_df, \ cn_g1_reads, cn_g1_states, cn_s_reads, cn_s_states, \ - gammas, rt_prior_profile, libs_g1, libs_s = self.process_input_data() + gammas, rt_prior_profile, rho_init, libs_g1, libs_s = self.process_input_data() # compute consensus clone profiles for cn state clone_cn_profiles = compute_consensus_clone_profiles( @@ -779,12 +789,6 @@ def run_pert_model(self): # use manhattan binarization method to come up with an initial guess for each cell's time in S-phase t_init, t_alpha_prior, t_beta_prior = self.guess_times(cn_s_reads, etas) - # if the user has provided an initialization for RT, extract that as a tensor - if self.rt_init_col is not None: - rho_init = torch.tensor(cn_s_reads_df[self.rt_init_col].values) - else: - rho_init = None - # condition gc betas of S-phase model using fitted results from G1-phase model model_s = poutine.condition( model_s,