Skip to content

Commit

Permalink
first attempt at passing RT initialization into model
Browse files Browse the repository at this point in the history
  • Loading branch information
adamcweiner committed Oct 3, 2023
1 parent 71a0086 commit 7822ff3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
9 changes: 5 additions & 4 deletions scdna_replication_tools/infer_scRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def get_args():

class scRT:
def __init__(self, cn_s, cn_g1, input_col='reads', assign_col='copy', library_col='library_id', ploidy_col='ploidy',
cell_col='cell_id', cn_state_col='state', chr_col='chr', start_col='start', gc_col='gc',
rv_col='rt_value', rs_col='rt_state', frac_rt_col='frac_rt', clone_col='clone_id', rt_prior_col='mcf7rt',
cn_prior_method='hmmcopy', col2='rpm_gc_norm', col3='temp_rt', col4='changepoint_segments', col5='binary_thresh',
cell_col='cell_id', cn_state_col='state', chr_col='chr', start_col='start', gc_col='gc',
frac_rt_col='frac_rt', clone_col='clone_id', rt_init_col=None, rt_prior_col=None, cn_prior_method='hmmcopy',
rv_col='rt_value', rs_col='rt_state', col2='rpm_gc_norm', col3='temp_rt', col4='changepoint_segments', col5='binary_thresh',
max_iter=2000, min_iter=100, max_iter_step1=None, min_iter_step1=None, max_iter_step3=None, min_iter_step3=None,
cn_prior_weight=1e6, learning_rate=0.05, rel_tol=1e-6, cuda=False, seed=0, P=13, K=4, upsilon=6, run_step3=True):
self.cn_s = cn_s
Expand All @@ -50,6 +50,7 @@ def __init__(self, cn_s, cn_g1, input_col='reads', assign_col='copy', library_co
self.gc_col = gc_col
self.ploidy_col = ploidy_col
self.rt_prior_col = rt_prior_col
self.rt_init_col = rt_init_col

# column representing continuous replication timing value of each bin
self.rv_col = rv_col
Expand Down Expand Up @@ -151,7 +152,7 @@ def infer_pert_model(self):

# run pyro model to get replication timing states
print('using {} as cn_prior_method'.format(self.cn_prior_method))
pert_model = pert_infer_scRT(self.cn_s, self.cn_g1, input_col=self.input_col, gc_col=self.gc_col, rt_prior_col=self.rt_prior_col,
pert_model = pert_infer_scRT(self.cn_s, self.cn_g1, input_col=self.input_col, gc_col=self.gc_col, rt_prior_col=self.rt_prior_col, rt_init_col=self.rt_init_col,
clone_col=self.clone_col, cell_col=self.cell_col, library_col=self.library_col, assign_col=self.assign_col,
chr_col=self.chr_col, start_col=self.start_col, cn_state_col=self.cn_state_col,
rs_col=self.rs_col, frac_rt_col=self.frac_rt_col, cn_prior_method=self.cn_prior_method, cn_prior_weight=self.cn_prior_weight,
Expand Down
29 changes: 20 additions & 9 deletions scdna_replication_tools/pert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


class pert_infer_scRT():
def __init__(self, cn_s, cn_g1, input_col='reads', gc_col='gc', rt_prior_col='mcf7rt',
def __init__(self, cn_s, cn_g1, input_col='reads', gc_col='gc', rt_prior_col=None, rt_init_col=None,
clone_col='clone_id', cell_col='cell_id', library_col='library_id',
chr_col='chr', start_col='start', cn_state_col='state', assign_col='copy',
rs_col='rt_state', frac_rt_col='frac_rt', cn_prior_method='g1_composite',
Expand All @@ -48,6 +48,7 @@ def __init__(self, cn_s, cn_g1, input_col='reads', gc_col='gc', rt_prior_col='mc
:param input_col: column containing read count input. (str)
:param gc_col: column for gc content of each bin. (str)
:param rt_prior_col: column of RepliSeq-determined replication timing values to be used as a prior. (str)
:param rt_init_col: column of RepliSeq-determined replication timing values to be used as an initialization. (str)
:param clone_col: column for clone ID of each cell. (str)
:param cell_col: column for cell ID of each cell. (str)
:param library_col: column for library ID of each cell. (str)
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(self, cn_s, cn_g1, input_col='reads', gc_col='gc', rt_prior_col='mc
self.input_col = input_col
self.gc_col = gc_col
self.rt_prior_col = rt_prior_col
self.rt_init_col = rt_init_col
self.clone_col = clone_col
self.cell_col = cell_col
self.library_col = library_col
Expand Down Expand Up @@ -532,7 +534,7 @@ def package_s_output(self, cn_s, trace_s, cn_s_reads_df, lambda_fit, losses_g, l


@config_enumerate
def model_s(self, gammas, libs, cn0=None, rho0=None, num_cells=None, num_loci=None, data=None, etas=None, lamb=None, lambda_init=1e-1, t_alpha_prior=None, t_beta_prior=None, t_init=None):
def model_s(self, gammas, libs, cn0=None, rho0=None, rho_init=None, num_cells=None, num_loci=None, data=None, etas=None, lamb=None, lambda_init=1e-1, t_alpha_prior=None, t_beta_prior=None, t_init=None):
with ignore_jit_warnings():
if data is not None:
num_loci, num_cells = data.shape
Expand Down Expand Up @@ -564,7 +566,10 @@ def model_s(self, gammas, libs, cn0=None, rho0=None, num_cells=None, num_loci=No
else:
with loci_plate:
# bulk replication timing profile
rho = pyro.sample('expose_rho', dist.Beta(torch.tensor([1.]), torch.tensor([1.])))
if rho_init is not None:
rho = pyro.param('expose_rho', rho_init, constraint=constraints.unit_interval)
else:
rho = pyro.sample('expose_rho', dist.Beta(torch.tensor([1.]), torch.tensor([1.])))

with cell_plate:

Expand Down Expand Up @@ -771,6 +776,15 @@ def run_pert_model(self):
pyro.clear_param_store()
pyro.enable_validation(False)

# use the Manhattan binarization method to come up with an initial guess for each cell's time in S-phase
t_init, t_alpha_prior, t_beta_prior = self.guess_times(cn_s_reads, etas)

# if the user has provided an initialization for RT, extract that as a tensor
if self.rt_init_col is not None:
rho_init = torch.tensor(cn_s_reads_df[self.rt_init_col].values)
else:
rho_init = None

# condition gc betas of S-phase model using fitted results from G1-phase model
model_s = poutine.condition(
model_s,
Expand All @@ -779,9 +793,6 @@ def run_pert_model(self):
'expose_beta_stds': beta_stds_fit,
})

# use manhattan binarization method to come up with an initial guess for each cell's time in S-phase
t_init, t_alpha_prior, t_beta_prior = self.guess_times(cn_s_reads, etas)

guide_s = AutoDelta(poutine.block(model_s, expose_fn=lambda msg: msg["name"].startswith("expose_")))
optim_s = pyro.optim.Adam({'lr': self.learning_rate, 'betas': [0.8, 0.99]})
elbo_s = JitTraceEnum_ELBO(max_plate_nesting=2)
Expand All @@ -791,7 +802,7 @@ def run_pert_model(self):
logging.info('STEP 2: Jointly infer replication and CN states in high variance cells.')
losses_s = []
for i in range(self.max_iter):
loss = svi_s.step(gammas, libs_s, data=cn_s_reads, etas=etas, lamb=lambda_fit, t_init=t_init)
loss = svi_s.step(gammas, libs_s, rho_init=rho_init, data=cn_s_reads, etas=etas, lamb=lambda_fit, t_init=t_init)

losses_s.append(loss)
logging.info('step: {}, loss: {}'.format(i, loss))
Expand All @@ -810,14 +821,14 @@ def run_pert_model(self):


# replay model
guide_trace_s = poutine.trace(guide_s).get_trace(gammas, libs_s, data=cn_s_reads, etas=etas, lamb=lambda_fit, t_init=t_init)
guide_trace_s = poutine.trace(guide_s).get_trace(gammas, libs_s, rho_init=rho_init, data=cn_s_reads, etas=etas, lamb=lambda_fit, t_init=t_init)
trained_model_s = poutine.replay(model_s, trace=guide_trace_s)

# infer discrete sites and get model trace
inferred_model_s = infer_discrete(
trained_model_s, temperature=0,
first_available_dim=-3)
trace_s = poutine.trace(inferred_model_s).get_trace(gammas, libs_s, data=cn_s_reads, etas=etas, lamb=lambda_fit, t_init=t_init)
trace_s = poutine.trace(inferred_model_s).get_trace(gammas, libs_s, rho_init=rho_init, data=cn_s_reads, etas=etas, lamb=lambda_fit, t_init=t_init)

# get output dataframes based on learned latent parameters and states
cn_s_out, supp_s_out_df = self.package_s_output(self.cn_s, trace_s, cn_s_reads_df, lambda_fit, losses_g, losses_s)
Expand Down

0 comments on commit 7822ff3

Please sign in to comment.