Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Update get_chain to include log-likelihoods
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Dec 14, 2021
1 parent 2d15c6c commit 4be80fc
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions src/beanmachine/ppl/inference/monte_carlo_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def __init__(
self.samples[rv] = val[:, num_adaptive_samples:]

if logll_results is not None:
logll = merge_dicts(logll_results, 0, stack_not_cat)
if isinstance(logll_results, list):
logll = merge_dicts(logll_results, 0, stack_not_cat)
else:
logll = logll_results
self.log_likelihoods = {}
self.adaptive_log_likelihoods = {}
for rv, val in logll.items():
Expand Down Expand Up @@ -92,7 +95,22 @@ def get_chain(self, chain: int = 0) -> "MonteCarloSamples":
raise IndexError("Please specify a valid chain")

samples = {rv: self.get_variable(rv, True)[[chain]] for rv in self}
new_mcs = MonteCarloSamples(samples, self.num_adaptive_samples)

if self.log_likelihoods is None:
logll = None
else:
logll = {
rv: self.get_log_likelihoods(rv, True)[[chain]]
for rv in self.log_likelihoods
}

new_mcs = MonteCarloSamples(
samples,
self.num_adaptive_samples,
True,
logll,
self.observations,
)
new_mcs.single_chain_view = True

return new_mcs
Expand Down Expand Up @@ -135,6 +153,30 @@ def get_variable(
samples = samples.squeeze(0)
return samples

def get_log_likelihoods(
self,
rv: RVIdentifier,
include_adapt_steps: bool = False,
) -> torch.Tensor:
"""
:returns: log_likelihoods computed during inference for the specified variable
"""

if not isinstance(rv, RVIdentifier):
raise TypeError(
"The key is required to be a random variable "
+ f"but is of type {type(rv).__name__}."
)

logll = self.log_likelihoods[rv]

if include_adapt_steps:
logll = torch.cat([self.adaptive_log_likelihoods[rv], logll], dim=1)

if self.single_chain_view:
logll = logll.squeeze(0)
return logll

def get(
self,
rv: RVIdentifier,
Expand Down

0 comments on commit 4be80fc

Please sign in to comment.