allow pi to be sampled from selected samples
jykr committed Apr 9, 2024
1 parent c599cf4 commit d538db9
Showing 5 changed files with 31 additions and 16 deletions.
13 changes: 7 additions & 6 deletions bean/cli/run.py
@@ -90,6 +90,8 @@ def main(args):
 
     model_label, model, guide = identify_model_guide(args)
     info("Done loading data. Preprocessing...")
+
+    # Format bdata into data structure compatible with Pyro model
     bdata = prepare_bdata(bdata, args, warn, prefix)
     guide_index = bdata.guides.index.copy()
     ndata = DATACLASS_DICT[args.selection][model_label](
@@ -103,7 +105,7 @@ def main(args):
         condition_column=args.condition_col,
         time_column=args.time_col,
         control_condition=args.control_condition,
-        control_can_be_selected=args.include_control_condition_for_inference,
+        control_can_be_selected=~args.exclude_control_condition_for_inference,
         allele_df_key=args.allele_df_key,
         control_guide_tag=args.control_guide_tag,
         target_col=args.target_col,
@@ -112,6 +114,8 @@ def main(args):
         replicate_col=args.replicate_col,
         use_bcmatch=(not args.ignore_bcmatch),
     )
+
+    # Build variant dataframe
     adj_negctrl_idx = None
     if args.library_design == "variant":
         if not args.uniform_edit:
Expand Down Expand Up @@ -151,8 +155,8 @@ def main(args):

guide_info_df = ndata.screen.guides

# Run the inference steps
info(f"Running inference for {model_label}...")

if args.load_existing:
with open(f"{prefix}/{model_label}.result.pkl", "rb") as handle:
param_history_dict = pkl.load(handle)
@@ -187,11 +191,8 @@ def main(args):
     if not os.path.exists(prefix):
         os.makedirs(prefix)
     with open(f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb") as handle:
-        # try:
         pkl.dump(save_dict, handle)
-        # except TypeError as exc:
-        # print(exc.message)
-        # print(param_history_dict)
+
     write_result_table(
         target_info_df,
         param_history_dict,
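A note on the `control_can_be_selected` line above: `args.exclude_control_condition_for_inference` is a `store_true` boolean, and Python's `~` is integer bitwise negation rather than boolean NOT, so it maps `False`/`True` to `-1`/`-2`, both of which are truthy. If the receiving dataclass consumes this as a plain truth value, `not` is the conventional spelling. A minimal standalone illustration (not code from this commit):

```python
exclude = False           # argparse store_true default

print(~exclude)           # -1: bitwise NOT of int(False), still truthy
print(~True)              # -2: also truthy
print(not exclude)        # True: ordinary boolean negation
```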
8 changes: 4 additions & 4 deletions bean/model/parser.py
@@ -83,14 +83,14 @@ def parse_args(parser=None):
         "--control-condition",
         default="bulk",
         type=str,
-        help="Value in `bdata.samples[condition_col]` that indicates control experimental condition.",
+        help="Value in `bdata.samples[condition_col]` that indicates control experimental condition whose editing patterns will be used.",
     )
     parser.add_argument(
-        "--include-control-condition-for-inference",
-        "-ic",
+        "--exclude-control-condition-for-inference",
+        "-ec",
         default=False,
         action="store_true",
-        help="Include control conditions for inference. Currently only supported for survival screens.",
+        help="Exclude control conditions for inference. Currently only supported for survival screens.",
     )
     parser.add_argument(
         "--replicate-col",
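The rename flips the CLI's default behavior: control samples now participate in inference unless the user opts out. A minimal sketch of the parse-time behavior, reduced to the one argument under discussion (parser setup abbreviated from the diff):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--exclude-control-condition-for-inference",
    "-ec",
    default=False,
    action="store_true",
)

# No flag: control condition participates in inference (the new default).
print(parser.parse_args([]).exclude_control_condition_for_inference)       # False
# Flag given: control condition is excluded.
print(parser.parse_args(["-ec"]).exclude_control_condition_for_inference)  # True
```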
4 changes: 2 additions & 2 deletions bean/model/readwrite.py
@@ -55,11 +55,11 @@ def write_result_table(
     negctrl_params=None,
     write_fitted_eff: bool = True,
     adjust_confidence_by_negative_control: bool = True,
-    adjust_confidence_negatives: np.ndarray = None,
+    adjust_confidence_negatives: Optional[np.ndarray] = None,
     guide_index: Optional[Sequence[str]] = None,
     guide_acc: Optional[Sequence] = None,
     sd_is_fitted: bool = True,
-    sample_covariates: List[str] = None,
+    sample_covariates: Optional[List[str]] = None,
     return_result: bool = False,
 ) -> Union[pd.DataFrame, None]:
     """Combine target information and scores to write result table to a csv file or return it."""
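The two signature edits above are typing fixes, not behavior changes: under PEP 484 a parameter defaulting to `None` should be annotated `Optional[...]`, and recent type checkers no longer infer that implicitly. A minimal sketch of the pattern (simplified signature, not the full function):

```python
from typing import List, Optional

import numpy as np


def write_result_table_sketch(
    adjust_confidence_negatives: Optional[np.ndarray] = None,  # was: np.ndarray = None
    sample_covariates: Optional[List[str]] = None,             # was: List[str] = None
) -> None:
    # Optional[...] tells the checker that None is a legal value which must
    # be handled before array/list operations are applied.
    if sample_covariates is None:
        sample_covariates = []
```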
12 changes: 8 additions & 4 deletions bean/model/survival_model.py
@@ -229,7 +229,6 @@ def MixtureNormalModel(
             "initial_guide_abundance",
             dist.Dirichlet(torch.ones((data.n_reps, data.n_guides))),
         )
-    # The pi should be Dirichlet distributed instead of independent betas
     alpha_pi = pyro.param(
         "alpha_pi",
         torch.ones(
@@ -248,6 +247,7 @@
     ), alpha_pi.shape
     with replicate_plate:
         with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
+            time_pi = data.control_timepoint
             # Accounting for sample specific overall edit rate across all guides.
             # P(allele | guide, bin=bulk)
             pi = pyro.sample(
@@ -262,9 +262,11 @@
                 data.n_guides,
                 2,
             ), pi.shape
+            # If pi is sampled in later timepoint, account for the selection.
+            expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi
             pyro.sample(
                 "bulk_allele_count",
-                dist.Multinomial(probs=pi, validate_args=False),
+                dist.Multinomial(probs=expanded_allele_p, validate_args=False),
                 obs=data.allele_counts_control,
             )
     if scale_by_accessibility:
@@ -277,7 +279,6 @@
         time = data.timepoints[t]
         assert time.shape == (data.n_condits,)
 
-        # with guide_plate, poutine.mask(mask=(data.allele_counts.sum(axis=-1) == 0)):
         with guide_plate:
             alleles_p_time = r.unsqueeze(0).expand(
                 (data.n_condits, -1, -1)
@@ -502,15 +503,18 @@ def MultiMixtureNormalModel(
     print(torch.where(alpha_pi < 0))
     with replicate_plate:
         with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
+            time_pi = data.control_timepoint
             pi = pyro.sample(
                 "pi",
                 dist.Dirichlet(
                     pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1)
                 ),
             )
+            # If pi is sampled in later timepoint, account for the selection.
+            expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi
             pyro.sample(
                 "bulk_allele_count",
-                dist.Multinomial(probs=pi, validate_args=False),
+                dist.Multinomial(probs=expanded_allele_p, validate_args=False),
                 obs=data.allele_counts_control,
             )
     if scale_by_accessibility:
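The modeling change, applied identically in `MixtureNormalModel` and `MultiMixtureNormalModel`, is the point of the commit: the control ("bulk") allele counts are no longer assumed to be drawn before selection. If the control sample is taken at timepoint t, each editing outcome's proportion has already been reweighted by its fitness r raised to t, i.e. P(outcome) is proportional to pi * r**t. A standalone toy sketch of that reweighting (not repository code; shapes collapsed to one guide with two outcomes):

```python
import torch

pi = torch.tensor([0.7, 0.3])  # outcome proportions at time 0 (the Dirichlet draw)
r = torch.tensor([1.0, 0.8])   # per-outcome fitness: survival/growth per unit time
control_timepoint = 2.0        # plays the role of data.control_timepoint

# Proportions in a control sample taken after selection acted for t time units:
expanded_allele_p = pi * r**control_timepoint  # [0.7, 0.192], unnormalized

# torch.distributions.Multinomial normalizes probs along the last dimension,
# so the unnormalized product can be passed directly, as the diff does.
bulk = torch.distributions.Multinomial(total_count=100, probs=expanded_allele_p)
print(bulk.probs)  # ~[0.785, 0.215]: selection shifts mass toward the fitter outcome
```

When `control_timepoint` is 0, `r**0` is 1 and the previous likelihood (`probs=pi`) is recovered, so the new parameterization strictly generalizes the old one.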
10 changes: 10 additions & 0 deletions bean/preprocessing/data_class.py
@@ -999,6 +999,16 @@ def _post_init(
         self.timepoints = torch.as_tensor(
             self.screen_selected.samples[self.time_column].unique()
         )
+        control_timepoint = self.screen_control.samples[self.time_column].unique()
+        if len(control_timepoint) != 1:
+            info(self.screen_control)
+            info(self.screen_control.samples)
+            info(control_timepoint)
+            raise ValueError(
+                "All samples with --control-condition should have the same --time-col column in ReporterScreen.samples[time_col]. Check your input ReporterScreen object."
+            )
+        else:
+            self.control_timepoint = control_timepoint[0]
         self.n_timepoints = self.n_condits
         timepoints = self.screen_selected.samples.sort_values(self.time_column)[
             self.time_column
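The exponent on r used by the survival models above has to be a single scalar per screen, which is what this new `_post_init` check enforces: every sample in the control condition must sit at one timepoint. A toy `samples` table that passes the check (column names are illustrative, not taken from this diff):

```python
import pandas as pd

samples = pd.DataFrame(
    {
        "condition": ["bulk", "bulk", "selected", "selected"],
        "time": [0, 0, 7, 14],
    }
)

# Mirrors the new validation: control samples must share one time value.
control_timepoint = samples.loc[samples["condition"] == "bulk", "time"].unique()
if len(control_timepoint) != 1:
    raise ValueError(
        "All samples with --control-condition should have the same --time-col value."
    )
print(control_timepoint[0])  # 0, stored as self.control_timepoint in the dataclass
```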
