From 21db8f48aef1aa8a0ca70dfc6e5daf3e3aa14ef7 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Tue, 27 Aug 2024 23:57:17 +0000 Subject: [PATCH] twoctrls10 --- bean/model/survival_model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/bean/model/survival_model.py b/bean/model/survival_model.py index 228124d..90eafbd 100755 --- a/bean/model/survival_model.py +++ b/bean/model/survival_model.py @@ -328,8 +328,14 @@ def MixtureNormalModel( # else: time_pi = data.control_timepoint[t] # If pi is sampled in later timepoint, account for the selection. - expanded_allele_p = pi * torch.pow( - r.expand(data.n_reps, 1, -1, -1), time_pi + expanded_allele_p = pi.expand(-1, len(time_pi), -1, -1) * torch.pow( + r.unsqueeze(0) + .unsqueeze(0) + .expand(data.n_reps, len(time_pi), -1, -1), + time_pi.unsqueeze(0) + .unsqueeze(-1) + .unsqueeze(-1) + .expand(data.n_reps, -1, data.n_guides, 2), ) pyro.sample( "control_allele_count",