From ef6b1e2d963b94ccbc57a7d8e64c9279d2b155cb Mon Sep 17 00:00:00 2001
From: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com>
Date: Tue, 20 Feb 2024 10:23:45 +0100
Subject: [PATCH 1/9] Pickle save enabled, pickle load now also preserves
 __deepcopy__ capability

---
 sbi/inference/posteriors/vi_posterior.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index 315a69950..5858bd5bc 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -2,7 +2,7 @@
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

 from copy import deepcopy
-from typing import Callable, Iterable, Optional, Union
+from typing import Callable, Dict, Iterable, Optional, Union

 import numpy as np
 import torch
@@ -566,3 +566,17 @@ def map(
             show_progress_bars=show_progress_bars,
             force_update=force_update,
         )
+
+    def __getstate__(self):
+        """This method is called when pickling the object."""
+        self._optimizer = None
+        self.__deepcopy__ = None
+        self._q_build_fn = None
+        self._q.__deepcopy__ = None
+        return self.__dict__
+
+    def __setstate__(self, state_dict: Dict):
+        self.__dict__ = state_dict
+        # Restore deepcopy compatibility
+        make_object_deepcopy_compatible(self)
+        make_object_deepcopy_compatible(self.q)

From 94453354857238c1a2d4a959d3fade8743746d2d Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 05:40:42 +0100
Subject: [PATCH 2/9] Fixing pyright errors

---
 sbi/samplers/mcmc/slice.py                  |  2 +-
 sbi/samplers/vi/vi_divergence_optimizers.py |  6 +++---
 sbi/samplers/vi/vi_pyro_flows.py            | 15 ++++++++-------
 sbi/samplers/vi/vi_utils.py                 | 12 ++++++------
 sbi/types.py                                |  4 +++-
 5 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/sbi/samplers/mcmc/slice.py b/sbi/samplers/mcmc/slice.py
index 3bca110c4..5dd01535c 100644
--- a/sbi/samplers/mcmc/slice.py
+++ b/sbi/samplers/mcmc/slice.py
@@ -19,7 +19,7 @@ def __init__(
         max_width=float("inf"),
         transforms: Optional[Dict] = None,
         max_plate_nesting: Optional[int] = None,
-        jit_compile: Optional[bool] = False,
+        jit_compile: bool = False,
         jit_options: Optional[Dict] = None,
         ignore_jit_warnings: bool = False,
     ) -> None:
diff --git a/sbi/samplers/vi/vi_divergence_optimizers.py b/sbi/samplers/vi/vi_divergence_optimizers.py
index e082ed9dd..774f87bd2 100644
--- a/sbi/samplers/vi/vi_divergence_optimizers.py
+++ b/sbi/samplers/vi/vi_divergence_optimizers.py
@@ -6,7 +6,7 @@

 import numpy as np
 import torch
-from pyro.distributions import TransformedDistribution
+
 from torch import Tensor, nn
 from torch.distributions import Distribution
 from torch.optim import ASGD, SGD, Adadelta, Adagrad, Adam, Adamax, AdamW, RMSprop
@@ -25,7 +25,7 @@
     make_object_deepcopy_compatible,
     move_all_tensor_to_device,
 )
-from sbi.types import Array
+from sbi.types import Array, PyroTransformedDistribution
 from sbi.utils import check_prior

 _VI_method = {}
@@ -42,7 +42,7 @@ class DivergenceOptimizer(ABC):
     def __init__(
         self,
         potential_fn: BasePotential,
-        q: TransformedDistribution,
+        q: PyroTransformedDistribution,
         prior: Optional[Distribution] = None,
         n_particles: int = 256,
         clip_value: float = 5.0,
diff --git a/sbi/samplers/vi/vi_pyro_flows.py b/sbi/samplers/vi/vi_pyro_flows.py
index bf241beee..b212f70a8 100644
--- a/sbi/samplers/vi/vi_pyro_flows.py
+++ b/sbi/samplers/vi/vi_pyro_flows.py
@@ -1,9 +1,10 @@
 from __future__ import annotations

-from typing import Callable, Iterable, List, Optional
+from typing import Callable, Iterable, List, Optional, Type

 import torch
 from pyro.distributions import transforms
+from pyro.distributions.torch_transform import TransformModule
 from pyro.nn import AutoRegressiveNN, DenseNN
 from torch import nn
 from torch.distributions import Distribution, Independent, Normal
@@ -19,7 +20,7 @@
 def register_transform(
-    cls: Optional[TorchTransform] = None,
+    cls: Optional[Type[TorchTransform]] = None,
     name: Optional[str] = None,
     inits: Callable = lambda *args, **kwargs: (args, kwargs),
 ) -> Callable:
@@ -278,7 +279,7 @@ class AffineTransform(transforms.AffineTransform):
     """Trainable version of an Affine transform. This can be used to get diagonal
     Gaussian approximations."""

-    __doc__ += transforms.AffineTransform.__doc__
+    __doc__ = transforms.AffineTransform.__doc__

     def parameters(self):
         self.loc.requires_grad_(True)
@@ -306,7 +307,7 @@ class LowerCholeskyAffine(transforms.LowerCholeskyAffine):
     """Trainable version of a Lower Cholesky Affine transform. This can be used to get
     full Gaussian approximations."""

-    __doc__ += transforms.LowerCholeskyAffine.__doc__
+    __doc__ = transforms.LowerCholeskyAffine.__doc__

     def parameters(self):
         self.loc.requires_grad_(True)
@@ -337,7 +338,7 @@ class TransformedDistribution(torch.distributions.TransformedDistribution):

     assert __doc__ is not None
     assert torch.distributions.TransformedDistribution.__doc__ is not None
-    __doc__ += torch.distributions.TransformedDistribution.__doc__
+    __doc__ = torch.distributions.TransformedDistribution.__doc__

     def parameters(self):
         if hasattr(self.base_dist, "parameters"):
@@ -354,7 +355,7 @@ def modules(self):

 def build_flow(
     event_shape: torch.Size,
-    link_flow: transforms.Transform,
+    link_flow: TorchTransform,
     num_transforms: int = 5,
     transform: str = "affine_autoregressive",
     permute: bool = True,
@@ -388,7 +389,7 @@ def build_flow(
     # `unsqueeze(0)` because the `link_flow` requires a batch dimension if the prior is
     # a `MultipleIndependent`.
     additional_dim = (
-        len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0])
+        len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0]) # type: ignore # Since link flow should never be None
         - torch.tensor(event_shape, device=device).item()
     )
     event_shape = torch.Size(
diff --git a/sbi/samplers/vi/vi_utils.py b/sbi/samplers/vi/vi_utils.py
index 1855f2c7d..3bbc5be14 100644
--- a/sbi/samplers/vi/vi_utils.py
+++ b/sbi/samplers/vi/vi_utils.py
@@ -13,14 +13,14 @@

 import numpy as np
 import torch
-from pyro.distributions import TransformedDistribution
+
 from pyro.distributions.torch_transform import TransformModule
 from torch import Tensor
-from torch.distributions import Distribution
+from torch.distributions import Distribution, TransformedDistribution
 from torch.distributions.transforms import ComposeTransform, IndependentTransform
 from torch.nn import Module

-from sbi.types import TorchTransform
+from sbi.types import TorchTransform, PyroTransformedDistribution


 def filter_kwrags_for_func(f: Callable, kwargs: Dict) -> Dict:
@@ -82,7 +82,7 @@ def get_modules(t: Union[TorchTransform, TransformModule]) -> Iterable:
         pass


-def check_parameters_modules_attribute(q: TransformedDistribution) -> None:
+def check_parameters_modules_attribute(q: PyroTransformedDistribution) -> None:
     """Checks a parameterized distribution object for valid `parameters` and `modules`.
     Args:
@@ -195,7 +195,7 @@ def add_parameters_module_attributes


 def add_parameter_attributes_to_transformed_distribution(
-    q: TransformedDistribution,
+    q: PyroTransformedDistribution,
 ) -> None:
     """A function that will add `parameters` and `modules` to q automatically, if q is
     a TransformedDistribution.
@@ -224,7 +224,7 @@


 def adapt_variational_distribution(
-    q: TransformedDistribution,
+    q: PyroTransformedDistribution,
     prior: Distribution,
     link_transform: Callable,
     parameters: Iterable = [],
diff --git a/sbi/types.py b/sbi/types.py
index eff85d926..047787aa7 100644
--- a/sbi/types.py
+++ b/sbi/types.py
@@ -32,10 +32,11 @@
 # Define alias types because otherwise, the documentation by mkdocs became very long and
 # made the website look ugly.
 TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
-TorchTransform = NewType("torch Transform", Transform)
+#TorchTransform = NewType("torch Transform", Transform)
 TorchModule = NewType("Module", Module)
 TorchDistribution = NewType("torch Distribution", Distribution)
 # See PEP 613 for the reason why we need to use TypeAlias here.
+TorchTransform: TypeAlias = Transform
 PyroTransformedDistribution: TypeAlias = TransformedDistribution
 TorchTensor = NewType("Tensor", Tensor)
@@ -46,6 +47,7 @@
     "ScalarFloat",
     "TensorboardSummaryWriter",
     "TorchModule",
+    "TorchTransform",
     "transform_types",
     "TorchDistribution",
     "PyroTransformedDistribution",

From 98cf0deca7da9d32573f2f0037746eee7f2a0999 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 05:50:29 +0100
Subject: [PATCH 3/9] Formatting and one new pyright issue

---
 sbi/inference/posteriors/vi_posterior.py | 2 +-
 sbi/samplers/vi/vi_pyro_flows.py         | 2 +-
 sbi/types.py                             | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index 5858bd5bc..d7f81701d 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -572,7 +572,7 @@ def __getstate__(self):
         self._optimizer = None
         self.__deepcopy__ = None
         self._q_build_fn = None
-        self._q.__deepcopy__ = None
+        self._q.__deepcopy__ = None  # type: ignore
         return self.__dict__

     def __setstate__(self, state_dict: Dict):
diff --git a/sbi/samplers/vi/vi_pyro_flows.py b/sbi/samplers/vi/vi_pyro_flows.py
index b212f70a8..4f4e6cfe8 100644
--- a/sbi/samplers/vi/vi_pyro_flows.py
+++ b/sbi/samplers/vi/vi_pyro_flows.py
@@ -389,7 +389,7 @@ def build_flow(
     # `unsqueeze(0)` because the `link_flow` requires a batch dimension if the prior is
     # a `MultipleIndependent`.
     additional_dim = (
-        len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0]) # type: ignore # Since link flow should never be None
+        len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0]) # type: ignore # Since link flow should never be None
         - torch.tensor(event_shape, device=device).item()
     )
     event_shape = torch.Size(
diff --git a/sbi/types.py b/sbi/types.py
index 047787aa7..ad2adbc14 100644
--- a/sbi/types.py
+++ b/sbi/types.py
@@ -32,7 +32,7 @@
 # Define alias types because otherwise, the documentation by mkdocs became very long and
 # made the website look ugly.
 TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
-#TorchTransform = NewType("torch Transform", Transform)
+# TorchTransform = NewType("torch Transform", Transform)
 TorchModule = NewType("Module", Module)
 TorchDistribution = NewType("torch Distribution", Distribution)
 # See PEP 613 for the reason why we need to use TypeAlias here.
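Note on patches 2-3: the pyright fixes hinge on `TorchTransform` becoming a PEP 613 `TypeAlias` rather than a `NewType`. A `NewType` is a distinct nominal type, so pyright rejects plain `Transform` instances (including subclasses such as the flows passed as `link_flow`) wherever the alias is expected; a `TypeAlias` is fully interchangeable with `Transform`. A minimal standalone sketch of the difference, not code from the series (`typing.TypeAlias` needs Python 3.10+; older versions can import it from `typing_extensions`):

    from typing import NewType, TypeAlias

    from torch.distributions.transforms import ExpTransform, Transform

    TransformNew = NewType("TransformNew", Transform)   # distinct nominal type
    TransformAlias: TypeAlias = Transform               # plain synonym

    def expects_new(t: TransformNew) -> None: ...
    def expects_alias(t: TransformAlias) -> None: ...

    t = ExpTransform()
    expects_new(t)    # pyright error: "ExpTransform" is not "TransformNew"
    expects_alias(t)  # OK: the alias accepts Transform and all of its subclasses

This is why annotations such as `link_flow: TorchTransform` and `cls: Optional[Type[TorchTransform]]` in patch 2 only type-check once the alias change in `sbi/types.py` lands.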
From 591f48a11a3ee2615abd8a3a85b9fea3167cab3a Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 06:03:43 +0100
Subject: [PATCH 4/9] Isort

---
 sbi/samplers/vi/vi_divergence_optimizers.py | 1 -
 sbi/samplers/vi/vi_utils.py                 | 3 +--
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/sbi/samplers/vi/vi_divergence_optimizers.py b/sbi/samplers/vi/vi_divergence_optimizers.py
index 774f87bd2..fd2466221 100644
--- a/sbi/samplers/vi/vi_divergence_optimizers.py
+++ b/sbi/samplers/vi/vi_divergence_optimizers.py
@@ -6,7 +6,6 @@

 import numpy as np
 import torch
-
 from torch import Tensor, nn
 from torch.distributions import Distribution
 from torch.optim import ASGD, SGD, Adadelta, Adagrad, Adam, Adamax, AdamW, RMSprop
diff --git a/sbi/samplers/vi/vi_utils.py b/sbi/samplers/vi/vi_utils.py
index 3bbc5be14..a35c3f1bb 100644
--- a/sbi/samplers/vi/vi_utils.py
+++ b/sbi/samplers/vi/vi_utils.py
@@ -13,14 +13,13 @@

 import numpy as np
 import torch
-
 from pyro.distributions.torch_transform import TransformModule
 from torch import Tensor
 from torch.distributions import Distribution, TransformedDistribution
 from torch.distributions.transforms import ComposeTransform, IndependentTransform
 from torch.nn import Module

-from sbi.types import TorchTransform, PyroTransformedDistribution
+from sbi.types import PyroTransformedDistribution, TorchTransform


 def filter_kwrags_for_func(f: Callable, kwargs: Dict) -> Dict:

From 5b4a6443986b7490ab9e81cea681d6ddefe3c7ce Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 07:05:07 +0100
Subject: [PATCH 5/9] Deep copy compatibility clashes with pickle
 compatibility; requires overriding __deepcopy__. Add tests.

---
 sbi/inference/posteriors/vi_posterior.py | 16 +++++-
 tests/vi_test.py                         | 67 +++++++++++++++++++-----
 2 files changed, 68 insertions(+), 15 deletions(-)

diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index d7f81701d..6c6010440 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -1,6 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

+import copy
 from copy import deepcopy
 from typing import Callable, Dict, Iterable, Optional, Union

 import numpy as np
 import torch
@@ -567,10 +568,23 @@ def map(
             force_update=force_update,
         )

+    def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior":
+        if memo is None:
+            memo = {}
+        # Create a new instance of the class
+        cls = self.__class__
+        result = cls.__new__(cls)
+        # Add to memo
+        memo[id(self)] = result
+        # Copy attributes
+        for k, v in self.__dict__.items():
+            setattr(result, k, copy.deepcopy(v, memo))
+        return result
+
     def __getstate__(self):
         """This method is called when pickling the object."""
         self._optimizer = None
-        self.__deepcopy__ = None
+        self.__deepcopy__ = None  # type: ignore
         self._q_build_fn = None
         self._q.__deepcopy__ = None  # type: ignore
         return self.__dict__
diff --git a/tests/vi_test.py b/tests/vi_test.py
index bbfcc613d..019435351 100644
--- a/tests/vi_test.py
+++ b/tests/vi_test.py
@@ -5,6 +5,7 @@

 from copy import deepcopy

+import os
 import numpy as np
 import pytest
 import torch
@@ -19,6 +20,13 @@
 from sbi.utils import MultipleIndependent
 from tests.test_utils import check_c2st

+class FakePotential(BasePotential):
+    def __call__(self, theta, **kwargs):
+        return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))
+
+    def allow_iid_x(self) -> bool:
+        return True
+

 @pytest.mark.slow
 @pytest.mark.parametrize("num_dim", (1, 2))
@@ -190,13 +198,6 @@ def test_deepcopy_support(q: str):

     num_dim = 2

-    class FakePotential(BasePotential):
-        def __call__(self, theta, **kwargs):
-            return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))
-
-        def allow_iid_x(self) -> bool:
-            return True
-
     prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
     potential_fn = FakePotential(prior=prior)
     theta_transform = torch_tf.identity_transform
@@ -214,21 +215,59 @@ def allow_iid_x(self) -> bool:
     assert (
         posterior._x == posterior_copy._x
     ).all(), "Mhh, something with the copy is strange"
+
+    # Try if they are the same
+    torch.manual_seed(0)
+    s1 = posterior._q.rsample()
+    torch.manual_seed(0)
+    s2 = posterior_copy._q.rsample()
+    assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange"

     # Produces nonleaf tensors in the cache... -> Can lead to failure of deepcopy.
     posterior.q.rsample()
     posterior_copy = deepcopy(posterior)

+@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
+def test_pickle_support(q: str):
+    """Tests if the VIPosterior can be saved and loaded via pickle.
+
+    Args:
+        q: Different variational posteriors.
+    """
+    num_dim = 2
+
+    prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
+    potential_fn = FakePotential(prior=prior)
+    theta_transform = torch_tf.identity_transform
+
+    posterior = VIPosterior(
+        potential_fn,
+        prior,
+        theta_transform=theta_transform,
+        q=q,
+    )
+    posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32)))
+
+    torch.save(posterior, "posterior.pt")
+    posterior_loaded = torch.load("posterior.pt")
+    assert (
+        posterior._x == posterior_loaded._x
+    ).all(), "Mhh, something with the pickled is strange"
+
+    # Try if they are the same
+    torch.manual_seed(0)
+    s1 = posterior._q.rsample()
+    torch.manual_seed(0)
+    s2 = posterior_loaded._q.rsample()
+
+    os.remove("posterior.pt")
+
+    assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange"
+
+
 def test_vi_posterior_inferface():
     num_dim = 2

-    class FakePotential(BasePotential):
-        def __call__(self, theta, **kwargs):
-            return torch.ones_like(torch.as_tensor(theta[:, 0], dtype=torch.float32))
-
-        def allow_iid_x(self) -> bool:
-            return True

     prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
     potential_fn = FakePotential(prior=prior)

From a2d72d9c3cd55b0e37c9cd1a61e0ee4d350c3b95 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 07:34:59 +0100
Subject: [PATCH 6/9] Forgot formatting tests

---
 tests/vi_test.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/tests/vi_test.py b/tests/vi_test.py
index 019435351..1785cc625 100644
--- a/tests/vi_test.py
+++ b/tests/vi_test.py
@@ -3,9 +3,9 @@

 from __future__ import annotations

+import os
 from copy import deepcopy

-import os
 import numpy as np
 import pytest
 import torch
@@ -20,6 +20,7 @@
 from sbi.utils import MultipleIndependent
 from tests.test_utils import check_c2st

+
 class FakePotential(BasePotential):
     def __call__(self, theta, **kwargs):
         return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))
@@ -215,8 +216,8 @@ def test_deepcopy_support(q: str):
     assert (
         posterior._x == posterior_copy._x
     ).all(), "Mhh, something with the copy is strange"
-
-    # Try if they are the same
+
+    # Try if they are the same
     torch.manual_seed(0)
     s1 = posterior._q.rsample()
     torch.manual_seed(0)
     s2 = posterior_copy._q.rsample()
     assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange"

     # Produces nonleaf tensors in the cache... -> Can lead to failure of deepcopy.
     posterior.q.rsample()
     posterior_copy = deepcopy(posterior)

+
 @pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
 def test_pickle_support(q: str):
     """Tests if the VIPosterior can be saved and loaded via pickle.
-
+
     Args:
         q: Different variational posteriors.
""" @@ -247,23 +249,22 @@ def test_pickle_support(q: str): q=q, ) posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32))) - + torch.save(posterior, "posterior.pt") posterior_loaded = torch.load("posterior.pt") assert ( posterior._x == posterior_loaded._x ).all(), "Mhh, something with the pickled is strange" - - # Try if they are the same + + # Try if they are the same torch.manual_seed(0) s1 = posterior._q.rsample() torch.manual_seed(0) s2 = posterior_loaded._q.rsample() - + os.remove("posterior.pt") - + assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange" - def test_vi_posterior_inferface(): From 7ba753b9955d7f2a97eddd17900ee8ac35526788 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Wed, 21 Feb 2024 07:55:25 +0100 Subject: [PATCH 7/9] Global fake potential, fix shapes --- tests/vi_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/vi_test.py b/tests/vi_test.py index 1785cc625..0e00273fb 100644 --- a/tests/vi_test.py +++ b/tests/vi_test.py @@ -23,7 +23,7 @@ class FakePotential(BasePotential): def __call__(self, theta, **kwargs): - return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32)) + return torch.ones(theta.shape[0], dtype=torch.float32) def allow_iid_x(self) -> bool: return True From 22fc7ccfc81b2b3a092a21ffe7c068f28fb7704b Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 22 Feb 2024 13:38:56 +0100 Subject: [PATCH 8/9] Restore attributes on reloading --- sbi/inference/posteriors/vi_posterior.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index 6c6010440..ea8e95dc9 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -207,7 +207,7 @@ def set_q( modules: List of modules associated with the distribution object. """ - self._q_arg = q + self._q_arg = (q, parameters, modules) if isinstance(q, Distribution): q = adapt_variational_distribution( q, @@ -591,6 +591,9 @@ def __getstate__(self): def __setstate__(self, state_dict: Dict): self.__dict__ = state_dict - # Restore deepcopy compatibility + q = deepcopy(self._q) + # Restore removed attributes + self.set_q(*self._q_arg) + self._q = q make_object_deepcopy_compatible(self) make_object_deepcopy_compatible(self.q) From ccac2bc0ba8db94e94b53ce23f18c7a1d633e895 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Thu, 22 Feb 2024 15:52:33 +0100 Subject: [PATCH 9/9] Improve documentation and Assertion errors --- sbi/inference/posteriors/vi_posterior.py | 31 ++++++++++++++++++++++-- tests/vi_test.py | 10 +++++--- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index ea8e95dc9..6d2748556 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -569,6 +569,18 @@ def map( ) def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior": + """This method is called when using `copy.deepcopy` on the object. + + It defines how the object is copied. We need to overwrite this method, since + the default implementation does use __getstate__ and __setstate__ which we + overwrite to enable pickling (and in particular the necessary modifications are incompatible deep copying). + + Args: + memo (Optional[Dict], optional): Deep copy internal memo. Defaults to None. + + Returns: + VIPosterior: Deep copy of the VIPosterior. 
+ """ if memo is None: memo = {} # Create a new instance of the class @@ -581,8 +593,15 @@ def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior": setattr(result, k, copy.deepcopy(v, memo)) return result - def __getstate__(self): - """This method is called when pickling the object.""" + def __getstate__(self) -> Dict: + """This method is called when pickling the object. + + It defines what is pickled. We need to overwrite this method, + since some parts due not support pickle protocols (e.g. due to local functions, etc.). + + Returns: + Dict: All attributes of the VIPosterior. + """ self._optimizer = None self.__deepcopy__ = None # type: ignore self._q_build_fn = None @@ -590,6 +609,14 @@ def __getstate__(self): return self.__dict__ def __setstate__(self, state_dict: Dict): + """This method is called when unpickling the object. + + Especially, we need to restore the removed attributes and ensure that the + object e.g. remains deep copy compatible. + + Args: + state_dict (Dict): Given state dictionary, we will restore the object from it. + """ self.__dict__ = state_dict q = deepcopy(self._q) # Restore removed attributes diff --git a/tests/vi_test.py b/tests/vi_test.py index 0e00273fb..06ef4ec71 100644 --- a/tests/vi_test.py +++ b/tests/vi_test.py @@ -211,18 +211,22 @@ def test_deepcopy_support(q: str): ) posterior_copy = deepcopy(posterior) posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32))) - assert posterior._x != posterior_copy._x, "Mhh, something with the copy is strange" + assert ( + posterior._x != posterior_copy._x + ), "Default x attributed of original and copied but modified VIPosterior must be the different, on change (otherwise it is not a deep copy)." posterior_copy = deepcopy(posterior) assert ( posterior._x == posterior_copy._x - ).all(), "Mhh, something with the copy is strange" + ).all(), "Default x attributed of original and copied VIPosterior must be the same." # Try if they are the same torch.manual_seed(0) s1 = posterior._q.rsample() torch.manual_seed(0) s2 = posterior_copy._q.rsample() - assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange" + assert torch.allclose( + s1, s2 + ), "Samples from original and unpickled VIPosterior must be close." # Produces nonleaf tensors in the cache... -> Can lead to failure of deepcopy. posterior.q.rsample()