Skip to content

Commit

Permalink
debug survival negctrl model
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Mar 29, 2024
1 parent 368ba55 commit a93bc83
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 26 deletions.
31 changes: 18 additions & 13 deletions bean/model/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def write_result_table(
)

fit_df = pd.DataFrame(param_dict)
fit_df["novl"] = get_novl(fit_df, "mu", "mu_sd")
if "negctrl" in param_hist_dict.keys():
print("Normalizing with common negative control distribution")
mu0 = param_hist_dict["negctrl"]["params"]["mu_loc"].detach().cpu().numpy()
Expand All @@ -114,6 +113,8 @@ def write_result_table(
.cpu()
.numpy()
)
else:
sd0 = 1.0
print(f"Fitted mu0={mu0}" + (f", sd0={sd0}." if sd_is_fitted else ""))
fit_df["mu_scaled"] = (mu - mu0) / sd0
fit_df["mu_sd_scaled"] = mu_sd / sd0
Expand Down Expand Up @@ -154,12 +155,12 @@ def write_result_table(
fit_df,
std,
suffix="_adj",
mu_adjusted_col="mu_scaled"
if "negctrl" in param_hist_dict.keys()
else "mu",
mu_sd_adjusted_col="mu_sd_scaled"
if "negctrl" in param_hist_dict.keys()
else "mu_sd",
mu_adjusted_col=(
"mu_scaled" if "negctrl" in param_hist_dict.keys() else "mu"
),
mu_sd_adjusted_col=(
"mu_sd_scaled" if "negctrl" in param_hist_dict.keys() else "mu_sd"
),
)
fit_df = add_credible_interval(fit_df, "mu_adj", "mu_sd_adj")
if sample_covariates is not None:
Expand All @@ -168,12 +169,16 @@ def write_result_table(
fit_df,
std,
suffix=f"_{sample_cov}_adj",
mu_adjusted_col=f"mu_{sample_cov}_scaled"
if "negctrl" in param_hist_dict.keys()
else f"mu_{sample_cov}",
mu_sd_adjusted_col=f"mu_sd_{sample_cov}_scaled"
if "negctrl" in param_hist_dict.keys()
else f"mu_sd_{sample_cov}",
mu_adjusted_col=(
f"mu_{sample_cov}_scaled"
if "negctrl" in param_hist_dict.keys()
else f"mu_{sample_cov}"
),
mu_sd_adjusted_col=(
f"mu_sd_{sample_cov}_scaled"
if "negctrl" in param_hist_dict.keys()
else f"mu_sd_{sample_cov}"
),
)
fit_df = add_credible_interval(
fit_df, f"mu_{sample_cov}_adj", f"mu_sd_{sample_cov}_adj"
Expand Down
17 changes: 17 additions & 0 deletions bean/model/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,20 @@ def identify_model_guide(args):
fit_noise=(not args.dont_fit_noise),
),
)


def identify_negctrl_model_guide(args, data_has_bcmatch):
if args.selection == "sorting":
m = sorting_model
else:
m = survival_model
negctrl_model = partial(
m.ControlNormalModel,
use_bcmatch=(not args.ignore_bcmatch and data_has_bcmatch),
)

negctrl_guide = partial(
m.ControlNormalGuide,
use_bcmatch=(not args.ignore_bcmatch and data_has_bcmatch),
)
return negctrl_model, negctrl_guide
10 changes: 5 additions & 5 deletions bean/model/survival_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True):
mu_alleles = pyro.sample("mu_alleles", dist.Laplace(0, 1))
mu = mu_alleles.repeat(data.n_guides).unsqueeze(-1)
r = torch.exp(mu)
with pyro.plate("rep_plate1", data.n_reps, dim=-1):
q_0 = pyro.sample(
"initial_guide_abundance",
dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))),
)
with replicate_plate:
with pyro.plate("guide_plate2", data.n_guides):
q_0 = pyro.sample(
"initial_guide_abundance",
dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))),
)
with time_plate as t:
time = data.timepoints[t]
assert time.shape == (data.n_condits,)
Expand Down
11 changes: 3 additions & 8 deletions bin/bean-run
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ from bean.model.run import (
parse_args,
check_args,
identify_model_guide,
identify_negctrl_model_guide,
)

logging.basicConfig(
Expand Down Expand Up @@ -144,14 +145,8 @@ def main(args, bdata):
run_inference(model, guide, ndata, num_steps=args.n_iter)
)
if args.fit_negctrl:
negctrl_model = partial(
m.ControlNormalModel,
use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers),
)

negctrl_guide = partial(
m.ControlNormalGuide,
use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers),
negctrl_model, negctrl_guide = identify_negctrl_model_guide(
args, "X_bcmatch" in bdata.layers
)
negctrl_idx = np.where(
guide_info_df[args.negctrl_col].map(lambda s: s.lower())
Expand Down

0 comments on commit a93bc83

Please sign in to comment.