add log_prob() to restricted_prior
michaeldeistler committed Feb 9, 2022
1 parent fff58fa commit 4f8253e
Showing 1 changed file with 114 additions and 11 deletions.
125 changes: 114 additions & 11 deletions sbi/utils/restriction_estimator.py
@@ -18,12 +18,13 @@
standardizing_net,
z_score_parser,
)
+from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.user_input_checks import validate_theta_and_x


def build_input_layer(
batch_theta: Tensor = None,
-z_score_theta: bool = True,
+z_score_theta: Optional[str] = "independent",
embedding_net_theta: nn.Module = nn.Identity(),
) -> nn.Module:
r"""Builds input layer for the `RestrictionEstimator` with option to z-score.
@@ -33,11 +34,13 @@ def build_input_layer(
Args:
batch_theta: Batch of $\theta$s, used to infer dimensionality and (optional)
z-scoring.
-z_score_theta: Whether to z-score $\theta$s passing into the network, can take one of the following:
-    - `none`, None: do not z-score
-    - `independent`: z-score each dimension independently
+z_score_theta: Whether to z-score parameters $\theta$ before passing them into
+    the network, can take one of the following:
+    - `none`, or None: do not z-score.
+    - `independent`: z-score each dimension independently.
     - `structured`: treat dimensions as related, therefore compute mean and std
-    over the entire batch, instead of per-dimension.
+    over the entire batch, instead of per-dimension. Should be used when each
+    sample is, for example, a time series or an image.
embedding_net_theta: Optional embedding network for $\theta$s.
Returns:
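
To make the three options concrete, here is a minimal, self-contained sketch of the z-scoring semantics described above (illustrative only, not sbi's internal implementation, which goes through `z_score_parser` and `standardizing_net`):

    import torch

    def z_score(batch_theta: torch.Tensor, mode: str) -> torch.Tensor:
        """Sketch of the three `z_score_theta` modes (not sbi's code)."""
        if mode == "independent":
            # One mean/std per parameter dimension.
            mean = batch_theta.mean(dim=0)
            std = batch_theta.std(dim=0)
        elif mode == "structured":
            # One mean/std shared across the entire batch, so related
            # dimensions (e.g. time bins of a time series) are scaled jointly.
            mean = batch_theta.mean()
            std = batch_theta.std()
        else:  # "none" or None
            return batch_theta
        return (batch_theta - mean) / std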
@@ -59,7 +62,7 @@ def build_classifier(
hidden_features: int = 100,
num_blocks: int = 2,
dropout_probability: float = 0.5,
-z_score_theta: bool = True,
+z_score_theta: Optional[str] = "independent",
embedding_net_theta: nn.Module = nn.Identity(),
) -> Callable:
"""
@@ -82,8 +85,13 @@
string.
dropout_probability: Dropout probability of the classifier if `model` is
`resnet`.
-z_score_theta: Whether to z-score the parameters $\theta$ used to train the
-    classifier.
+z_score_theta: Whether to z-score parameters $\theta$ before passing them into
+    the network, can take one of the following:
+    - `none`, or None: do not z-score.
+    - `independent`: z-score each dimension independently.
+    - `structured`: treat dimensions as related, therefore compute mean and std
+    over the entire batch, instead of per-dimension. Should be used when each
+    sample is, for example, a time series or an image.
embedding_net_theta: Neural network used to encode the parameters before they
are passed to the classifier.
@@ -441,6 +449,9 @@ def restrict_prior(
"""
if classifier is None:
classifier = self._classifier

+classifier.zero_grad()

return RestrictedPrior(
self._prior,
classifier,
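
For orientation, the `RestrictedPrior` constructed here is usually obtained at the end of the standard restriction-estimator workflow. A minimal sketch following the sbi tutorials (`prior`, `theta`, and `x` are assumed to exist; the import path may differ across sbi versions):

    from sbi.utils import RestrictionEstimator

    restriction_estimator = RestrictionEstimator(prior=prior)
    # theta/x: simulated parameters and outputs, including failed simulations.
    restriction_estimator.append_simulations(theta, x)
    classifier = restriction_estimator.train()
    # Returns the RestrictedPrior built in the snippet above.
    restricted_prior = restriction_estimator.restrict_prior()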
@@ -511,6 +522,7 @@ def __init__(
self._validation_label = validation_label
self._classifier_thr = None
self._reweigh_factor = None
+self.acceptance_rate = None

self.tune_rejection_threshold(allowed_false_negatives)

@@ -532,6 +544,7 @@ def sample(
sample_shape: Shape = torch.Size(),
show_progress_bars: bool = False,
max_sampling_batch_size: int = 10_000,
+save_acceptance_rate: bool = False,
) -> Tensor:
"""
Return samples from the `RestrictedPrior`.
@@ -544,9 +557,11 @@
sample_shape: Shape of the returned samples.
show_progress_bars: Whether or not to show a progressbar during sampling.
max_sampling_batch_size: Batch size for drawing samples from the posterior.
-    Takes effect only in the second iteration of the loop below, i.e., in case
-    of leakage or `num_samples>max_sampling_batch_size`. Larger batch size
-    speeds up sampling.
+    Takes effect only in the second iteration of the loop below, i.e., in
+    case of leakage or `num_samples>max_sampling_batch_size`. Larger batch
+    size speeds up sampling.
+save_acceptance_rate: If `True`, the acceptance rate is saved such that it
+    can be reused later in `log_prob()`.
Returns:
Samples from the `RestrictedPrior`.
@@ -587,6 +602,9 @@ def sample(
# fixed batch size.
sampling_batch_size = max_sampling_batch_size

+if save_acceptance_rate:
+    self.acceptance_rate = torch.as_tensor(acceptance_rate)

pbar.close()
print(
f"The classifier rejected {(1.0 - acceptance_rate) * 100:.1f}% of all "
@@ -602,6 +620,90 @@

return samples

+def log_prob(
+    self,
+    theta: Tensor,
+    norm_restricted_prior: bool = True,
+    track_gradients: bool = False,
+    prior_acceptance_params: Optional[dict] = None,
+) -> Tensor:
+    r"""Returns the log-probability of the restricted prior.
+
+    Args:
+        theta: Parameters $\theta$.
+        norm_restricted_prior: Whether to enforce a normalized restricted prior
+            density. The normalizing factor is calculated via rejection
+            sampling, so if you need speedier but unnormalized log-probability
+            estimates, set `norm_restricted_prior=False`. The returned log
+            probability is set to -∞ outside of the restricted prior support
+            regardless of this setting.
+        track_gradients: Whether the returned tensor supports tracking
+            gradients. This can be helpful for e.g. sensitivity analysis, but
+            increases memory consumption.
+        prior_acceptance_params: A `dict` of keyword arguments to override the
+            default values of `prior_acceptance()`. Possible options are:
+            `num_rejection_samples`, `force_update`, `show_progress_bars`, and
+            `rejection_sampling_batch_size`. These parameters only have an
+            effect if `norm_restricted_prior=True`.
+
+    Returns:
+        `(len(θ),)`-shaped log probability for θ in the support of the
+        restricted prior, -∞ (corresponding to 0 probability) outside.
+    """
+    theta = ensure_theta_batched(torch.as_tensor(theta))
+
+    with torch.set_grad_enabled(track_gradients):
+
+        # Evaluate on device, move back to cpu for comparison with prior.
+        prior_log_prob = self._prior.log_prob(theta)
+        accepted_by_classifier = self.predict(theta)
+
+        masked_log_prob = torch.where(
+            accepted_by_classifier.bool(),
+            prior_log_prob,
+            torch.tensor(float("-inf"), dtype=torch.float32),
+        )
+
+        if prior_acceptance_params is None:
+            prior_acceptance_params = dict()  # use defaults
+        log_factor = (
+            torch.log(self.prior_acceptance(**prior_acceptance_params))
+            if norm_restricted_prior
+            else 0
+        )
+
+        return masked_log_prob - log_factor

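To make the normalization concrete: inside the accepted region, the restricted prior equals the prior divided by the acceptance rate α, so its log-probability is the prior log-probability minus log α; outside, it is -∞. A worked example with made-up numbers:

    import torch

    # A uniform prior over [0, 1]^2 has log_prob = 0 everywhere. Suppose the
    # classifier accepts a region holding half the prior mass (alpha = 0.5).
    prior_log_prob = torch.tensor([0.0, 0.0])
    accepted = torch.tensor([True, False])
    alpha = torch.tensor(0.5)

    masked = torch.where(accepted, prior_log_prob, torch.tensor(float("-inf")))
    log_prob_restricted = masked - torch.log(alpha)
    print(log_prob_restricted)  # tensor([0.6931, -inf]): log(1/0.5) inside.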
+@torch.no_grad()
+def prior_acceptance(
+    self,
+    num_rejection_samples: int = 10_000,
+    force_update: bool = False,
+    show_progress_bars: bool = False,
+    rejection_sampling_batch_size: int = 10_000,
+) -> Tensor:
+    r"""Return the fraction of prior samples accepted by the classifier.
+
+    The factor is estimated from the acceptance probability during rejection
+    sampling from the prior.
+
+    Args:
+        num_rejection_samples: Number of samples used to estimate the
+            correction factor.
+        force_update: Whether to re-estimate the acceptance rate even if a
+            previously saved estimate exists.
+        show_progress_bars: Whether to show a progress bar during sampling.
+        rejection_sampling_batch_size: Batch size for rejection sampling.
+
+    Returns:
+        Estimated acceptance rate.
+    """
+    if self.acceptance_rate is None or force_update:
+        _ = self.sample(
+            (num_rejection_samples,),
+            show_progress_bars=show_progress_bars,
+            max_sampling_batch_size=rejection_sampling_batch_size,
+            save_acceptance_rate=True,
+        )
+    return self.acceptance_rate  # type: ignore

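Putting the two methods together, a typical call pattern (hypothetical usage; `restricted_prior` and `theta` as above) estimates the normalizing constant once and then reuses the cached value:

    # First call triggers rejection sampling and caches the acceptance rate.
    alpha = restricted_prior.prior_acceptance(num_rejection_samples=50_000)
    # Reuses the cached alpha; pass prior_acceptance_params={"force_update": True}
    # to re-estimate it.
    log_probs = restricted_prior.log_prob(theta)
    # Faster, unnormalized variant (still -inf outside the support).
    log_probs_unnorm = restricted_prior.log_prob(theta, norm_restricted_prior=False)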
def predict(self, theta: Tensor) -> Tensor:
r"""
Run classifier to predict whether the parameter set is `invalid` or `valid`.
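The body of `predict` is collapsed in this diff; a plausible sketch, inferred from how the threshold is tuned below (the two-column softmax layout is an assumption):

    import torch

    def predict_sketch(classifier, theta, classifier_thr):
        # Probability that theta is `valid`, assuming column 1 holds that class.
        probs = torch.softmax(classifier(theta), dim=-1)[:, 1]
        # Accept whenever the probability clears the tuned threshold.
        return (probs >= classifier_thr).int()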
@@ -685,6 +787,7 @@ def tune_rejection_threshold(
quantile_index = floor(num_valid * allowed_false_negatives)
self._classifier_thr, _ = torch.kthvalue(clf_probs, quantile_index + 1)

+self._classifier_thr = self._classifier_thr.detach()
if print_fp_rate:
self.print_false_positive_rate()

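For reference, the quantile rule above picks the threshold so that at most a fraction `allowed_false_negatives` of the classifier's probabilities on known-valid parameters falls below it. A worked example with made-up probabilities:

    import torch
    from math import floor

    clf_probs = torch.tensor([0.92, 0.85, 0.99, 0.40, 0.77])  # P(valid) on valid set
    allowed_false_negatives = 0.2
    quantile_index = floor(len(clf_probs) * allowed_false_negatives)  # floor(1.0) = 1
    thr, _ = torch.kthvalue(clf_probs, quantile_index + 1)  # 2nd-smallest value
    print(thr)  # tensor(0.7700): only 0.40 (20% of valid samples) falls below.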
