diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index 315a69950..6d2748556 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -1,8 +1,9 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
 
+import copy
 from copy import deepcopy
-from typing import Callable, Iterable, Optional, Union
+from typing import Callable, Dict, Iterable, Optional, Union
 
 import numpy as np
 import torch
@@ -206,7 +207,7 @@ def set_q(
             modules: List of modules associated with the distribution object.
         """
-        self._q_arg = q
+        self._q_arg = (q, parameters, modules)
         if isinstance(q, Distribution):
             q = adapt_variational_distribution(
                 q,
@@ -566,3 +567,60 @@ def map(
             show_progress_bars=show_progress_bars,
             force_update=force_update,
         )
+
+    def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior":
+        """This method is called when using `copy.deepcopy` on the object.
+
+        It defines how the object is copied. We need to overwrite this method, since
+        the default implementation uses `__getstate__` and `__setstate__`, which we
+        overwrite to enable pickling (and whose modifications are incompatible with
+        deep copying).
+
+        Args:
+            memo (Optional[Dict], optional): Deep copy internal memo. Defaults to None.
+
+        Returns:
+            VIPosterior: Deep copy of the VIPosterior.
+        """
+        if memo is None:
+            memo = {}
+        # Create a new instance of the class
+        cls = self.__class__
+        result = cls.__new__(cls)
+        # Add to memo
+        memo[id(self)] = result
+        # Copy attributes
+        for k, v in self.__dict__.items():
+            setattr(result, k, copy.deepcopy(v, memo))
+        return result
+
+    def __getstate__(self) -> Dict:
+        """This method is called when pickling the object.
+
+        It defines what is pickled. We need to overwrite this method, since some
+        parts do not support the pickle protocol (e.g. due to local functions, etc.).
+
+        Returns:
+            Dict: All attributes of the VIPosterior.
+        """
+        self._optimizer = None
+        self.__deepcopy__ = None  # type: ignore
+        self._q_build_fn = None
+        self._q.__deepcopy__ = None  # type: ignore
+        return self.__dict__
+
+    def __setstate__(self, state_dict: Dict):
+        """This method is called when unpickling the object.
+
+        In particular, we need to restore the removed attributes and ensure that the
+        object remains deep-copy compatible.
+
+        Args:
+            state_dict (Dict): State dictionary from which the object is restored.
+        """
+        self.__dict__ = state_dict
+        q = deepcopy(self._q)
+        # Restore removed attributes
+        self.set_q(*self._q_arg)
+        self._q = q
+        make_object_deepcopy_compatible(self)
+        make_object_deepcopy_compatible(self.q)
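# Usage sketch (illustrative, not part of the patch): with the `__deepcopy__`,
# `__getstate__`, and `__setstate__` methods added above, an already-built VIPosterior
# (here assumed to be bound to the name `posterior`, e.g. constructed as in the tests
# further down) survives both deep copying and a pickle round trip; the file name
# "posterior.pt" is arbitrary.
from copy import deepcopy

import torch

posterior_copy = deepcopy(posterior)             # dispatches to VIPosterior.__deepcopy__
torch.save(posterior, "posterior.pt")            # pickling goes through __getstate__
posterior_restored = torch.load("posterior.pt")  # attributes restored in __setstate__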
+ """ + self.__dict__ = state_dict + q = deepcopy(self._q) + # Restore removed attributes + self.set_q(*self._q_arg) + self._q = q + make_object_deepcopy_compatible(self) + make_object_deepcopy_compatible(self.q) diff --git a/sbi/samplers/mcmc/slice.py b/sbi/samplers/mcmc/slice.py index 3bca110c4..5dd01535c 100644 --- a/sbi/samplers/mcmc/slice.py +++ b/sbi/samplers/mcmc/slice.py @@ -19,7 +19,7 @@ def __init__( max_width=float("inf"), transforms: Optional[Dict] = None, max_plate_nesting: Optional[int] = None, - jit_compile: Optional[bool] = False, + jit_compile: bool = False, jit_options: Optional[Dict] = None, ignore_jit_warnings: bool = False, ) -> None: diff --git a/sbi/samplers/vi/vi_divergence_optimizers.py b/sbi/samplers/vi/vi_divergence_optimizers.py index e082ed9dd..fd2466221 100644 --- a/sbi/samplers/vi/vi_divergence_optimizers.py +++ b/sbi/samplers/vi/vi_divergence_optimizers.py @@ -6,7 +6,6 @@ import numpy as np import torch -from pyro.distributions import TransformedDistribution from torch import Tensor, nn from torch.distributions import Distribution from torch.optim import ASGD, SGD, Adadelta, Adagrad, Adam, Adamax, AdamW, RMSprop @@ -25,7 +24,7 @@ make_object_deepcopy_compatible, move_all_tensor_to_device, ) -from sbi.types import Array +from sbi.types import Array, PyroTransformedDistribution from sbi.utils import check_prior _VI_method = {} @@ -42,7 +41,7 @@ class DivergenceOptimizer(ABC): def __init__( self, potential_fn: BasePotential, - q: TransformedDistribution, + q: PyroTransformedDistribution, prior: Optional[Distribution] = None, n_particles: int = 256, clip_value: float = 5.0, diff --git a/sbi/samplers/vi/vi_pyro_flows.py b/sbi/samplers/vi/vi_pyro_flows.py index bf241beee..4f4e6cfe8 100644 --- a/sbi/samplers/vi/vi_pyro_flows.py +++ b/sbi/samplers/vi/vi_pyro_flows.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Callable, Iterable, List, Optional +from typing import Callable, Iterable, List, Optional, Type import torch from pyro.distributions import transforms +from pyro.distributions.torch_transform import TransformModule from pyro.nn import AutoRegressiveNN, DenseNN from torch import nn from torch.distributions import Distribution, Independent, Normal @@ -19,7 +20,7 @@ def register_transform( - cls: Optional[TorchTransform] = None, + cls: Optional[Type[TorchTransform]] = None, name: Optional[str] = None, inits: Callable = lambda *args, **kwargs: (args, kwargs), ) -> Callable: @@ -278,7 +279,7 @@ class AffineTransform(transforms.AffineTransform): """Trainable version of an Affine transform. This can be used to get diagonal Gaussian approximations.""" - __doc__ += transforms.AffineTransform.__doc__ + __doc__ = transforms.AffineTransform.__doc__ def parameters(self): self.loc.requires_grad_(True) @@ -306,7 +307,7 @@ class LowerCholeskyAffine(transforms.LowerCholeskyAffine): """Trainable version of a Lower Cholesky Affine transform. 
diff --git a/sbi/samplers/vi/vi_utils.py b/sbi/samplers/vi/vi_utils.py
index 1855f2c7d..a35c3f1bb 100644
--- a/sbi/samplers/vi/vi_utils.py
+++ b/sbi/samplers/vi/vi_utils.py
@@ -13,14 +13,13 @@
 
 import numpy as np
 import torch
-from pyro.distributions import TransformedDistribution
 from pyro.distributions.torch_transform import TransformModule
 from torch import Tensor
-from torch.distributions import Distribution
+from torch.distributions import Distribution, TransformedDistribution
 from torch.distributions.transforms import ComposeTransform, IndependentTransform
 from torch.nn import Module
 
-from sbi.types import TorchTransform
+from sbi.types import PyroTransformedDistribution, TorchTransform
 
 
 def filter_kwrags_for_func(f: Callable, kwargs: Dict) -> Dict:
@@ -82,7 +81,7 @@ def get_modules(t: Union[TorchTransform, TransformModule]) -> Iterable:
         pass
 
 
-def check_parameters_modules_attribute(q: TransformedDistribution) -> None:
+def check_parameters_modules_attribute(q: PyroTransformedDistribution) -> None:
     """Checks a parameterized distribution object for valid `parameters` and `modules`.
 
     Args:
@@ -195,7 +194,7 @@ def add_parameters_module_attributes(
 
 
 def add_parameter_attributes_to_transformed_distribution(
-    q: TransformedDistribution,
+    q: PyroTransformedDistribution,
 ) -> None:
     """A function that will add `parameters` and `modules` to q automatically, if q is
     a TransformedDistribution.
@@ -224,7 +223,7 @@ def modules():
 
 
 def adapt_variational_distribution(
-    q: TransformedDistribution,
+    q: PyroTransformedDistribution,
     prior: Distribution,
     link_transform: Callable,
     parameters: Iterable = [],
diff --git a/sbi/types.py b/sbi/types.py
index eff85d926..ad2adbc14 100644
--- a/sbi/types.py
+++ b/sbi/types.py
@@ -32,10 +32,11 @@
 # Define alias types because otherwise, the documentation by mkdocs became very long and
 # made the website look ugly.
 TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
-TorchTransform = NewType("torch Transform", Transform)
+# TorchTransform = NewType("torch Transform", Transform)
 TorchModule = NewType("Module", Module)
 TorchDistribution = NewType("torch Distribution", Distribution)
 # See PEP 613 for the reason why we need to use TypeAlias here.
+TorchTransform: TypeAlias = Transform
 PyroTransformedDistribution: TypeAlias = TransformedDistribution
 TorchTensor = NewType("Tensor", Tensor)
 
@@ -46,6 +47,7 @@
     "ScalarFloat",
     "TensorboardSummaryWriter",
     "TorchModule",
+    "TorchTransform",
     "transform_types",
     "TorchDistribution",
     "PyroTransformedDistribution",
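# Illustrative sketch (not part of the patch): a `NewType` cannot be used inside
# `Type[...]`, which the new `Optional[Type[TorchTransform]]` annotation in
# vi_pyro_flows.py relies on; presumably this is why `TorchTransform` becomes a
# PEP 613 `TypeAlias` here.
from typing import Optional, Type

from torch.distributions.transforms import Transform
from typing_extensions import TypeAlias

TorchTransform: TypeAlias = Transform


def takes_a_transform_class(cls: Optional[Type[TorchTransform]] = None) -> None:
    # With a TypeAlias, Type[TorchTransform] is simply Type[Transform] for the checker.
    pass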
diff --git a/tests/vi_test.py b/tests/vi_test.py
index bbfcc613d..06ef4ec71 100644
--- a/tests/vi_test.py
+++ b/tests/vi_test.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import os
 from copy import deepcopy
 
 import numpy as np
@@ -20,6 +21,14 @@
 from tests.test_utils import check_c2st
 
 
+class FakePotential(BasePotential):
+    def __call__(self, theta, **kwargs):
+        return torch.ones(theta.shape[0], dtype=torch.float32)
+
+    def allow_iid_x(self) -> bool:
+        return True
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("num_dim", (1, 2))
 @pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha"))
@@ -190,13 +199,6 @@ def test_deepcopy_support(q: str):
 
     num_dim = 2
 
-    class FakePotential(BasePotential):
-        def __call__(self, theta, **kwargs):
-            return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))
-
-        def allow_iid_x(self) -> bool:
-            return True
-
     prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
     potential_fn = FakePotential(prior=prior)
     theta_transform = torch_tf.identity_transform
@@ -209,26 +211,68 @@ def allow_iid_x(self) -> bool:
     )
     posterior_copy = deepcopy(posterior)
     posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32)))
-    assert posterior._x != posterior_copy._x, "Mhh, something with the copy is strange"
+    assert (
+        posterior._x != posterior_copy._x
+    ), "Default x attribute of the original and the modified copy must differ (otherwise it is not a deep copy)."
     posterior_copy = deepcopy(posterior)
     assert (
         posterior._x == posterior_copy._x
-    ).all(), "Mhh, something with the copy is strange"
+    ).all(), "Default x attribute of original and copied VIPosterior must be the same."
+
+    # Check that the original and the copy produce the same samples.
+    torch.manual_seed(0)
+    s1 = posterior._q.rsample()
+    torch.manual_seed(0)
+    s2 = posterior_copy._q.rsample()
+    assert torch.allclose(
+        s1, s2
+    ), "Samples from original and copied VIPosterior must be close."
 
     # Produces nonleaf tensors in the cache... -> Can lead to failure of deepcopy.
     posterior.q.rsample()
     posterior_copy = deepcopy(posterior)
 
 
-def test_vi_posterior_inferface():
+@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
+def test_pickle_support(q: str):
+    """Tests if the VIPosterior can be saved and loaded via pickle.
+
+    Args:
+        q: Different variational posteriors.
+ """ num_dim = 2 - class FakePotential(BasePotential): - def __call__(self, theta, **kwargs): - return torch.ones_like(torch.as_tensor(theta[:, 0], dtype=torch.float32)) + prior = MultivariateNormal(zeros(num_dim), eye(num_dim)) + potential_fn = FakePotential(prior=prior) + theta_transform = torch_tf.identity_transform - def allow_iid_x(self) -> bool: - return True + posterior = VIPosterior( + potential_fn, + prior, + theta_transform=theta_transform, + q=q, + ) + posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32))) + + torch.save(posterior, "posterior.pt") + posterior_loaded = torch.load("posterior.pt") + assert ( + posterior._x == posterior_loaded._x + ).all(), "Mhh, something with the pickled is strange" + + # Try if they are the same + torch.manual_seed(0) + s1 = posterior._q.rsample() + torch.manual_seed(0) + s2 = posterior_loaded._q.rsample() + + os.remove("posterior.pt") + + assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange" + + +def test_vi_posterior_inferface(): + num_dim = 2 prior = MultivariateNormal(zeros(num_dim), eye(num_dim)) potential_fn = FakePotential(prior=prior)