Pickle save for VIPosterior #951

Merged: 9 commits, Feb 22, 2024
62 changes: 60 additions & 2 deletions sbi/inference/posteriors/vi_posterior.py
@@ -1,8 +1,9 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

import copy
from copy import deepcopy
from typing import Callable, Iterable, Optional, Union
from typing import Callable, Dict, Iterable, Optional, Union

import numpy as np
import torch
@@ -206,7 +207,7 @@
modules: List of modules associated with the distribution object.

"""
self._q_arg = q
self._q_arg = (q, parameters, modules)
if isinstance(q, Distribution):
q = adapt_variational_distribution(
q,
@@ -566,3 +567,60 @@
show_progress_bars=show_progress_bars,
force_update=force_update,
)

def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior":
"""This method is called when using `copy.deepcopy` on the object.

It defines how the object is copied. We need to overwrite this method, since the
default implementation uses `__getstate__` and `__setstate__`, which we overwrite to
enable pickling (and the modifications required for pickling are incompatible with
deep copying).

Args:
memo (Optional[Dict], optional): Deep copy internal memo. Defaults to None.

Returns:
VIPosterior: Deep copy of the VIPosterior.
"""
if memo is None:
memo = {}

# Create a new instance of the class
cls = self.__class__
result = cls.__new__(cls)
# Add to memo
memo[id(self)] = result
# Copy attributes
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result

def __getstate__(self) -> Dict:
"""This method is called when pickling the object.

It defines what is pickled. We need to overwrite this method, since some parts
do not support the pickle protocol (e.g. due to local functions).

Returns:
Dict: All attributes of the VIPosterior.
"""
self._optimizer = None
self.__deepcopy__ = None # type: ignore
self._q_build_fn = None
self._q.__deepcopy__ = None # type: ignore
return self.__dict__

def __setstate__(self, state_dict: Dict):
"""This method is called when unpickling the object.

In particular, we need to restore the removed attributes and ensure that the
object remains deep-copy compatible.

Args:
state_dict (Dict): State dictionary from which the object is restored.
"""
self.__dict__ = state_dict
q = deepcopy(self._q)
# Restore removed attributes
self.set_q(*self._q_arg)
self._q = q
make_object_deepcopy_compatible(self)
make_object_deepcopy_compatible(self.q)
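
With the hooks above in place, a trained `VIPosterior` should survive both `copy.deepcopy` and `torch.save`/`torch.load`. A minimal usage sketch (the fitted `posterior` object is assumed; the file name mirrors the one used in the tests below):

    import copy
    import torch

    # `posterior` is assumed to be a fitted VIPosterior with a default x set.
    posterior_copy = copy.deepcopy(posterior)      # routed through the custom __deepcopy__
    torch.save(posterior, "posterior.pt")          # __getstate__ drops the unpicklable parts
    posterior_loaded = torch.load("posterior.pt")  # __setstate__ restores q and its attributes
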
2 changes: 1 addition & 1 deletion sbi/samplers/mcmc/slice.py
@@ -19,7 +19,7 @@ def __init__(
max_width=float("inf"),
transforms: Optional[Dict] = None,
max_plate_nesting: Optional[int] = None,
jit_compile: Optional[bool] = False,
jit_compile: bool = False,
jit_options: Optional[Dict] = None,
ignore_jit_warnings: bool = False,
) -> None:
5 changes: 2 additions & 3 deletions sbi/samplers/vi/vi_divergence_optimizers.py
@@ -6,7 +6,6 @@

import numpy as np
import torch
from pyro.distributions import TransformedDistribution
from torch import Tensor, nn
from torch.distributions import Distribution
from torch.optim import ASGD, SGD, Adadelta, Adagrad, Adam, Adamax, AdamW, RMSprop
@@ -25,7 +24,7 @@
make_object_deepcopy_compatible,
move_all_tensor_to_device,
)
from sbi.types import Array
from sbi.types import Array, PyroTransformedDistribution
from sbi.utils import check_prior

_VI_method = {}
@@ -42,7 +41,7 @@ class DivergenceOptimizer(ABC):
def __init__(
self,
potential_fn: BasePotential,
q: TransformedDistribution,
q: PyroTransformedDistribution,
prior: Optional[Distribution] = None,
n_particles: int = 256,
clip_value: float = 5.0,
15 changes: 8 additions & 7 deletions sbi/samplers/vi/vi_pyro_flows.py
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Callable, Iterable, List, Optional
from typing import Callable, Iterable, List, Optional, Type

import torch
from pyro.distributions import transforms
from pyro.distributions.torch_transform import TransformModule
from pyro.nn import AutoRegressiveNN, DenseNN
from torch import nn
from torch.distributions import Distribution, Independent, Normal
@@ -19,7 +20,7 @@


def register_transform(
cls: Optional[TorchTransform] = None,
cls: Optional[Type[TorchTransform]] = None,
name: Optional[str] = None,
inits: Callable = lambda *args, **kwargs: (args, kwargs),
) -> Callable:
@@ -278,7 +279,7 @@ class AffineTransform(transforms.AffineTransform):
"""Trainable version of an Affine transform. This can be used to get diagonal
Gaussian approximations."""

__doc__ += transforms.AffineTransform.__doc__
__doc__ = transforms.AffineTransform.__doc__

def parameters(self):
self.loc.requires_grad_(True)
@@ -306,7 +307,7 @@ class LowerCholeskyAffine(transforms.LowerCholeskyAffine):
"""Trainable version of a Lower Cholesky Affine transform. This can be used to get
full Gaussian approximations."""

__doc__ += transforms.LowerCholeskyAffine.__doc__
__doc__ = transforms.LowerCholeskyAffine.__doc__

def parameters(self):
self.loc.requires_grad_(True)
@@ -337,7 +338,7 @@ class TransformedDistribution(torch.distributions.TransformedDistribution):

assert __doc__ is not None
assert torch.distributions.TransformedDistribution.__doc__ is not None
__doc__ += torch.distributions.TransformedDistribution.__doc__
__doc__ = torch.distributions.TransformedDistribution.__doc__

def parameters(self):
if hasattr(self.base_dist, "parameters"):
@@ -354,7 +355,7 @@ def modules(self):

def build_flow(
event_shape: torch.Size,
link_flow: transforms.Transform,
link_flow: TorchTransform,
num_transforms: int = 5,
transform: str = "affine_autoregressive",
permute: bool = True,
@@ -388,7 +389,7 @@ def build_flow(
# `unsqueeze(0)` because the `link_flow` requires a batch dimension if the prior is
# a `MultipleIndependent`.
additional_dim = (
len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0])
len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0]) # type: ignore # Since link flow should never be None
- torch.tensor(event_shape, device=device).item()
)
event_shape = torch.Size(
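For context, `build_flow` is the function that assembles these registered transforms into a trainable variational distribution. A rough sketch of calling it with only the arguments visible in this hunk (the values are illustrative, not taken from the PR):

    import torch
    from torch.distributions.transforms import identity_transform

    from sbi.samplers.vi.vi_pyro_flows import build_flow

    # Two-dimensional affine-autoregressive flow with an identity link transform.
    q = build_flow(
        event_shape=torch.Size([2]),
        link_flow=identity_transform,
        num_transforms=3,
        transform="affine_autoregressive",
    )
    samples = q.rsample((100,))  # the returned distribution supports reparameterized sampling
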
11 changes: 5 additions & 6 deletions sbi/samplers/vi/vi_utils.py
@@ -13,14 +13,13 @@

import numpy as np
import torch
from pyro.distributions import TransformedDistribution
from pyro.distributions.torch_transform import TransformModule
from torch import Tensor
from torch.distributions import Distribution
from torch.distributions import Distribution, TransformedDistribution
from torch.distributions.transforms import ComposeTransform, IndependentTransform
from torch.nn import Module

from sbi.types import TorchTransform
from sbi.types import PyroTransformedDistribution, TorchTransform


def filter_kwrags_for_func(f: Callable, kwargs: Dict) -> Dict:
@@ -82,7 +81,7 @@ def get_modules(t: Union[TorchTransform, TransformModule]) -> Iterable:
pass


def check_parameters_modules_attribute(q: TransformedDistribution) -> None:
def check_parameters_modules_attribute(q: PyroTransformedDistribution) -> None:
"""Checks a parameterized distribution object for valid `parameters` and `modules`.

Args:
@@ -195,7 +194,7 @@ def add_parameters_module_attributes(


def add_parameter_attributes_to_transformed_distribution(
q: TransformedDistribution,
q: PyroTransformedDistribution,
) -> None:
"""A function that will add `parameters` and `modules` to q automatically, if q is a
TransformedDistribution.
Expand Down Expand Up @@ -224,7 +223,7 @@ def modules():


def adapt_variational_distribution(
q: TransformedDistribution,
q: PyroTransformedDistribution,
prior: Distribution,
link_transform: Callable,
parameters: Iterable = [],
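The helpers touched here exist so that a user-supplied `TransformedDistribution` can expose its trainable tensors to the divergence optimizer. A toy sketch of the kind of object they operate on (hypothetical example, not part of this PR):

    import torch
    from torch.distributions import Independent, Normal, TransformedDistribution
    from torch.distributions.transforms import AffineTransform

    # Trainable location and scale of a diagonal-Gaussian variational distribution.
    loc = torch.zeros(2, requires_grad=True)
    scale = torch.ones(2, requires_grad=True)
    q = TransformedDistribution(
        Independent(Normal(torch.zeros(2), torch.ones(2)), 1),
        AffineTransform(loc, scale),
    )
    # `adapt_variational_distribution` (or `VIPosterior.set_q`) can then receive the
    # trainable tensors explicitly via its `parameters` argument, e.g. [loc, scale].
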
4 changes: 3 additions & 1 deletion sbi/types.py
@@ -32,10 +32,11 @@
# Define alias types because otherwise, the documentation by mkdocs became very long and
# made the website look ugly.
TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
TorchTransform = NewType("torch Transform", Transform)
# TorchTransform = NewType("torch Transform", Transform)
TorchModule = NewType("Module", Module)
TorchDistribution = NewType("torch Distribution", Distribution)
# See PEP 613 for the reason why we need to use TypeAlias here.
TorchTransform: TypeAlias = Transform
PyroTransformedDistribution: TypeAlias = TransformedDistribution
TorchTensor = NewType("Tensor", Tensor)

@@ -46,6 +47,7 @@
"ScalarFloat",
"TensorboardSummaryWriter",
"TorchModule",
"TorchTransform",
"transform_types",
"TorchDistribution",
"PyroTransformedDistribution",
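The switch from `NewType` to a PEP 613 `TypeAlias` matters for static type checking: a `NewType` is a distinct nominal type that plain `Transform` instances are not assignable to, whereas an alias accepts the class and every subclass. A small illustration (the function is hypothetical, not part of sbi):

    from typing import TypeAlias  # use typing_extensions on Python < 3.10

    from torch.distributions.transforms import ExpTransform, Transform

    TorchTransform: TypeAlias = Transform

    def apply_transform(t: TorchTransform) -> None:
        ...

    # Accepted by type checkers, since ExpTransform is a Transform subclass.
    apply_transform(ExpTransform())
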
74 changes: 59 additions & 15 deletions tests/vi_test.py
@@ -3,6 +3,7 @@

from __future__ import annotations

import os
from copy import deepcopy

import numpy as np
@@ -20,6 +21,14 @@
from tests.test_utils import check_c2st


class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
return torch.ones(theta.shape[0], dtype=torch.float32)

def allow_iid_x(self) -> bool:
return True


@pytest.mark.slow
@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha"))
@@ -190,13 +199,6 @@ def test_deepcopy_support(q: str):

num_dim = 2

class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))

def allow_iid_x(self) -> bool:
return True

prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
potential_fn = FakePotential(prior=prior)
theta_transform = torch_tf.identity_transform
@@ -209,26 +211,68 @@ def allow_iid_x(self) -> bool:
)
posterior_copy = deepcopy(posterior)
posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32)))
assert posterior._x != posterior_copy._x, "Mhh, something with the copy is strange"
assert (
posterior._x != posterior_copy._x
), "Default x attributed of original and copied but modified VIPosterior must be the different, on change (otherwise it is not a deep copy)."
posterior_copy = deepcopy(posterior)
assert (
posterior._x == posterior_copy._x
).all(), "Mhh, something with the copy is strange"
).all(), "Default x attributed of original and copied VIPosterior must be the same."

# Check that original and copy produce the same samples
torch.manual_seed(0)
s1 = posterior._q.rsample()
torch.manual_seed(0)
s2 = posterior_copy._q.rsample()
assert torch.allclose(
s1, s2
), "Samples from original and unpickled VIPosterior must be close."

# Sampling produces non-leaf tensors in the cache, which can make deepcopy fail.
posterior.q.rsample()
posterior_copy = deepcopy(posterior)


def test_vi_posterior_inferface():
@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
def test_pickle_support(q: str):
"""Tests if the VIPosterior can be saved and loaded via pickle.

Args:
q: Different variational posteriors.
"""
num_dim = 2

class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
return torch.ones_like(torch.as_tensor(theta[:, 0], dtype=torch.float32))
prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
potential_fn = FakePotential(prior=prior)
theta_transform = torch_tf.identity_transform

def allow_iid_x(self) -> bool:
return True
posterior = VIPosterior(
potential_fn,
prior,
theta_transform=theta_transform,
q=q,
)
posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32)))

torch.save(posterior, "posterior.pt")
posterior_loaded = torch.load("posterior.pt")
assert (
posterior._x == posterior_loaded._x
).all(), "Mhh, something with the pickled is strange"

# Check that original and loaded posterior produce the same samples
torch.manual_seed(0)
s1 = posterior._q.rsample()
torch.manual_seed(0)
s2 = posterior_loaded._q.rsample()

os.remove("posterior.pt")

assert torch.allclose(s1, s2), "Samples from original and loaded VIPosterior must be close."


def test_vi_posterior_inferface():
num_dim = 2

prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
potential_fn = FakePotential(prior=prior)