From ef6b1e2d963b94ccbc57a7d8e64c9279d2b155cb Mon Sep 17 00:00:00 2001
From: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com>
Date: Tue, 20 Feb 2024 10:23:45 +0100
Subject: [PATCH 1/9] Pickle save enabled, pickle load now also preserves
__deepcopy__ capability
---
sbi/inference/posteriors/vi_posterior.py | 16 +++++++++++++++-
1 file changed, 15 insertions(+), 1 deletion(-)
diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index 315a69950..5858bd5bc 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -2,7 +2,7 @@
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
from copy import deepcopy
-from typing import Callable, Iterable, Optional, Union
+from typing import Callable, Dict, Iterable, Optional, Union
import numpy as np
import torch
@@ -566,3 +566,17 @@ def map(
show_progress_bars=show_progress_bars,
force_update=force_update,
)
+
+ def __getstate__(self):
+ """This method is called when pickling the object."""
+ self._optimizer = None
+ self.__deepcopy__ = None
+ self._q_build_fn = None
+ self._q.__deepcopy__ = None
+ return self.__dict__
+
+ def __setstate__(self, state_dict: Dict):
+ self.__dict__ = state_dict
+ # Restore deepcopy compatibility
+ make_object_deepcopy_compatible(self)
+ make_object_deepcopy_compatible(self.q)
From 94453354857238c1a2d4a959d3fade8743746d2d Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 05:40:42 +0100
Subject: [PATCH 2/9] Fixing pyright errors
---
sbi/samplers/mcmc/slice.py | 2 +-
sbi/samplers/vi/vi_divergence_optimizers.py | 6 +++---
sbi/samplers/vi/vi_pyro_flows.py | 15 ++++++++-------
sbi/samplers/vi/vi_utils.py | 12 ++++++------
sbi/types.py | 4 +++-
5 files changed, 21 insertions(+), 18 deletions(-)
diff --git a/sbi/samplers/mcmc/slice.py b/sbi/samplers/mcmc/slice.py
index 3bca110c4..5dd01535c 100644
--- a/sbi/samplers/mcmc/slice.py
+++ b/sbi/samplers/mcmc/slice.py
@@ -19,7 +19,7 @@ def __init__(
max_width=float("inf"),
transforms: Optional[Dict] = None,
max_plate_nesting: Optional[int] = None,
- jit_compile: Optional[bool] = False,
+ jit_compile: bool = False,
jit_options: Optional[Dict] = None,
ignore_jit_warnings: bool = False,
) -> None:
diff --git a/sbi/samplers/vi/vi_divergence_optimizers.py b/sbi/samplers/vi/vi_divergence_optimizers.py
index e082ed9dd..774f87bd2 100644
--- a/sbi/samplers/vi/vi_divergence_optimizers.py
+++ b/sbi/samplers/vi/vi_divergence_optimizers.py
@@ -6,7 +6,7 @@
import numpy as np
import torch
-from pyro.distributions import TransformedDistribution
+
from torch import Tensor, nn
from torch.distributions import Distribution
from torch.optim import ASGD, SGD, Adadelta, Adagrad, Adam, Adamax, AdamW, RMSprop
@@ -25,7 +25,7 @@
make_object_deepcopy_compatible,
move_all_tensor_to_device,
)
-from sbi.types import Array
+from sbi.types import Array, PyroTransformedDistribution
from sbi.utils import check_prior
_VI_method = {}
@@ -42,7 +42,7 @@ class DivergenceOptimizer(ABC):
def __init__(
self,
potential_fn: BasePotential,
- q: TransformedDistribution,
+ q: PyroTransformedDistribution,
prior: Optional[Distribution] = None,
n_particles: int = 256,
clip_value: float = 5.0,
diff --git a/sbi/samplers/vi/vi_pyro_flows.py b/sbi/samplers/vi/vi_pyro_flows.py
index bf241beee..b212f70a8 100644
--- a/sbi/samplers/vi/vi_pyro_flows.py
+++ b/sbi/samplers/vi/vi_pyro_flows.py
@@ -1,9 +1,10 @@
from __future__ import annotations
-from typing import Callable, Iterable, List, Optional
+from typing import Callable, Iterable, List, Optional, Type
import torch
from pyro.distributions import transforms
+from pyro.distributions.torch_transform import TransformModule
from pyro.nn import AutoRegressiveNN, DenseNN
from torch import nn
from torch.distributions import Distribution, Independent, Normal
@@ -19,7 +20,7 @@
def register_transform(
- cls: Optional[TorchTransform] = None,
+ cls: Optional[Type[TorchTransform]] = None,
name: Optional[str] = None,
inits: Callable = lambda *args, **kwargs: (args, kwargs),
) -> Callable:
@@ -278,7 +279,7 @@ class AffineTransform(transforms.AffineTransform):
"""Trainable version of an Affine transform. This can be used to get diagonal
Gaussian approximations."""
- __doc__ += transforms.AffineTransform.__doc__
+ __doc__ = transforms.AffineTransform.__doc__
def parameters(self):
self.loc.requires_grad_(True)
@@ -306,7 +307,7 @@ class LowerCholeskyAffine(transforms.LowerCholeskyAffine):
"""Trainable version of a Lower Cholesky Affine transform. This can be used to get
full Gaussian approximations."""
- __doc__ += transforms.LowerCholeskyAffine.__doc__
+ __doc__ = transforms.LowerCholeskyAffine.__doc__
def parameters(self):
self.loc.requires_grad_(True)
@@ -337,7 +338,7 @@ class TransformedDistribution(torch.distributions.TransformedDistribution):
assert __doc__ is not None
assert torch.distributions.TransformedDistribution.__doc__ is not None
- __doc__ += torch.distributions.TransformedDistribution.__doc__
+ __doc__ = torch.distributions.TransformedDistribution.__doc__
def parameters(self):
if hasattr(self.base_dist, "parameters"):
@@ -354,7 +355,7 @@ def modules(self):
def build_flow(
event_shape: torch.Size,
- link_flow: transforms.Transform,
+ link_flow: TorchTransform,
num_transforms: int = 5,
transform: str = "affine_autoregressive",
permute: bool = True,
@@ -388,7 +389,7 @@ def build_flow(
# `unsqueeze(0)` because the `link_flow` requires a batch dimension if the prior is
# a `MultipleIndependent`.
additional_dim = (
- len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0])
+ len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0]) # type: ignore # Since link flow should never be None
- torch.tensor(event_shape, device=device).item()
)
event_shape = torch.Size(
diff --git a/sbi/samplers/vi/vi_utils.py b/sbi/samplers/vi/vi_utils.py
index 1855f2c7d..3bbc5be14 100644
--- a/sbi/samplers/vi/vi_utils.py
+++ b/sbi/samplers/vi/vi_utils.py
@@ -13,14 +13,14 @@
import numpy as np
import torch
-from pyro.distributions import TransformedDistribution
+
from pyro.distributions.torch_transform import TransformModule
from torch import Tensor
-from torch.distributions import Distribution
+from torch.distributions import Distribution, TransformedDistribution
from torch.distributions.transforms import ComposeTransform, IndependentTransform
from torch.nn import Module
-from sbi.types import TorchTransform
+from sbi.types import TorchTransform, PyroTransformedDistribution
def filter_kwrags_for_func(f: Callable, kwargs: Dict) -> Dict:
@@ -82,7 +82,7 @@ def get_modules(t: Union[TorchTransform, TransformModule]) -> Iterable:
pass
-def check_parameters_modules_attribute(q: TransformedDistribution) -> None:
+def check_parameters_modules_attribute(q: PyroTransformedDistribution) -> None:
"""Checks a parameterized distribution object for valid `parameters` and `modules`.
Args:
@@ -195,7 +195,7 @@ def add_parameters_module_attributes(
def add_parameter_attributes_to_transformed_distribution(
- q: TransformedDistribution,
+ q: PyroTransformedDistribution,
) -> None:
"""A function that will add `parameters` and `modules` to q automatically, if q is a
TransformedDistribution.
@@ -224,7 +224,7 @@ def modules():
def adapt_variational_distribution(
- q: TransformedDistribution,
+ q: PyroTransformedDistribution,
prior: Distribution,
link_transform: Callable,
parameters: Iterable = [],
diff --git a/sbi/types.py b/sbi/types.py
index eff85d926..047787aa7 100644
--- a/sbi/types.py
+++ b/sbi/types.py
@@ -32,10 +32,11 @@
# Define alias types because otherwise, the documentation by mkdocs became very long and
# made the website look ugly.
TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
-TorchTransform = NewType("torch Transform", Transform)
+#TorchTransform = NewType("torch Transform", Transform)
TorchModule = NewType("Module", Module)
TorchDistribution = NewType("torch Distribution", Distribution)
# See PEP 613 for the reason why we need to use TypeAlias here.
+TorchTransform: TypeAlias = Transform
PyroTransformedDistribution: TypeAlias = TransformedDistribution
TorchTensor = NewType("Tensor", Tensor)
@@ -46,6 +47,7 @@
"ScalarFloat",
"TensorboardSummaryWriter",
"TorchModule",
+ "TorchTransform",
"transform_types",
"TorchDistribution",
"PyroTransformedDistribution",
From 98cf0deca7da9d32573f2f0037746eee7f2a0999 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 05:50:29 +0100
Subject: [PATCH 3/9] Formatting and one new pyright issue
---
sbi/inference/posteriors/vi_posterior.py | 2 +-
sbi/samplers/vi/vi_pyro_flows.py | 2 +-
sbi/types.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index 5858bd5bc..d7f81701d 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -572,7 +572,7 @@ def __getstate__(self):
self._optimizer = None
self.__deepcopy__ = None
self._q_build_fn = None
- self._q.__deepcopy__ = None
+ self._q.__deepcopy__ = None # type: ignore
return self.__dict__
def __setstate__(self, state_dict: Dict):
diff --git a/sbi/samplers/vi/vi_pyro_flows.py b/sbi/samplers/vi/vi_pyro_flows.py
index b212f70a8..4f4e6cfe8 100644
--- a/sbi/samplers/vi/vi_pyro_flows.py
+++ b/sbi/samplers/vi/vi_pyro_flows.py
@@ -389,7 +389,7 @@ def build_flow(
# `unsqueeze(0)` because the `link_flow` requires a batch dimension if the prior is
# a `MultipleIndependent`.
additional_dim = (
- len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0]) # type: ignore # Since link flow should never be None
+ len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0]) # type: ignore # Since link flow should never be None
- torch.tensor(event_shape, device=device).item()
)
event_shape = torch.Size(
diff --git a/sbi/types.py b/sbi/types.py
index 047787aa7..ad2adbc14 100644
--- a/sbi/types.py
+++ b/sbi/types.py
@@ -32,7 +32,7 @@
# Define alias types because otherwise, the documentation by mkdocs became very long and
# made the website look ugly.
TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
-#TorchTransform = NewType("torch Transform", Transform)
+# TorchTransform = NewType("torch Transform", Transform)
TorchModule = NewType("Module", Module)
TorchDistribution = NewType("torch Distribution", Distribution)
# See PEP 613 for the reason why we need to use TypeAlias here.
From 591f48a11a3ee2615abd8a3a85b9fea3167cab3a Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 06:03:43 +0100
Subject: [PATCH 4/9] Isort
---
sbi/samplers/vi/vi_divergence_optimizers.py | 1 -
sbi/samplers/vi/vi_utils.py | 3 +--
2 files changed, 1 insertion(+), 3 deletions(-)
diff --git a/sbi/samplers/vi/vi_divergence_optimizers.py b/sbi/samplers/vi/vi_divergence_optimizers.py
index 774f87bd2..fd2466221 100644
--- a/sbi/samplers/vi/vi_divergence_optimizers.py
+++ b/sbi/samplers/vi/vi_divergence_optimizers.py
@@ -6,7 +6,6 @@
import numpy as np
import torch
-
from torch import Tensor, nn
from torch.distributions import Distribution
from torch.optim import ASGD, SGD, Adadelta, Adagrad, Adam, Adamax, AdamW, RMSprop
diff --git a/sbi/samplers/vi/vi_utils.py b/sbi/samplers/vi/vi_utils.py
index 3bbc5be14..a35c3f1bb 100644
--- a/sbi/samplers/vi/vi_utils.py
+++ b/sbi/samplers/vi/vi_utils.py
@@ -13,14 +13,13 @@
import numpy as np
import torch
-
from pyro.distributions.torch_transform import TransformModule
from torch import Tensor
from torch.distributions import Distribution, TransformedDistribution
from torch.distributions.transforms import ComposeTransform, IndependentTransform
from torch.nn import Module
-from sbi.types import TorchTransform, PyroTransformedDistribution
+from sbi.types import PyroTransformedDistribution, TorchTransform
def filter_kwrags_for_func(f: Callable, kwargs: Dict) -> Dict:
From 5b4a6443986b7490ab9e81cea681d6ddefe3c7ce Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 07:05:07 +0100
Subject: [PATCH 5/9] Deep copy compatibility clashes with pickle
compatibility. Requires overriding __deepcopy__. Adding tests
---
sbi/inference/posteriors/vi_posterior.py | 16 +++++-
tests/vi_test.py | 67 +++++++++++++++++++-----
2 files changed, 68 insertions(+), 15 deletions(-)
diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index d7f81701d..6c6010440 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
+import copy
from copy import deepcopy
from typing import Callable, Dict, Iterable, Optional, Union
@@ -567,10 +568,23 @@ def map(
force_update=force_update,
)
+ def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior":
+ if memo is None:
+ memo = {}
+ # Create a new instance of the class
+ cls = self.__class__
+ result = cls.__new__(cls)
+ # Add to memo
+ memo[id(self)] = result
+ # Copy attributes
+ for k, v in self.__dict__.items():
+ setattr(result, k, copy.deepcopy(v, memo))
+ return result
+
def __getstate__(self):
"""This method is called when pickling the object."""
self._optimizer = None
- self.__deepcopy__ = None
+ self.__deepcopy__ = None # type: ignore
self._q_build_fn = None
self._q.__deepcopy__ = None # type: ignore
return self.__dict__
diff --git a/tests/vi_test.py b/tests/vi_test.py
index bbfcc613d..019435351 100644
--- a/tests/vi_test.py
+++ b/tests/vi_test.py
@@ -5,6 +5,7 @@
from copy import deepcopy
+import os
import numpy as np
import pytest
import torch
@@ -19,6 +20,13 @@
from sbi.utils import MultipleIndependent
from tests.test_utils import check_c2st
+class FakePotential(BasePotential):
+ def __call__(self, theta, **kwargs):
+ return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))
+
+ def allow_iid_x(self) -> bool:
+ return True
+
@pytest.mark.slow
@pytest.mark.parametrize("num_dim", (1, 2))
@@ -190,13 +198,6 @@ def test_deepcopy_support(q: str):
num_dim = 2
- class FakePotential(BasePotential):
- def __call__(self, theta, **kwargs):
- return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))
-
- def allow_iid_x(self) -> bool:
- return True
-
prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
potential_fn = FakePotential(prior=prior)
theta_transform = torch_tf.identity_transform
@@ -214,21 +215,59 @@ def allow_iid_x(self) -> bool:
assert (
posterior._x == posterior_copy._x
).all(), "Mhh, something with the copy is strange"
+
+ # Try if they are the same
+ torch.manual_seed(0)
+ s1 = posterior._q.rsample()
+ torch.manual_seed(0)
+ s2 = posterior_copy._q.rsample()
+ assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange"
# Produces nonleaf tensors in the cache... -> Can lead to failure of deepcopy.
posterior.q.rsample()
posterior_copy = deepcopy(posterior)
-
-def test_vi_posterior_inferface():
+@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
+def test_pickle_support(q: str):
+ """Tests if the VIPosterior can be saved and loaded via pickle.
+
+ Args:
+ q: Different variational posteriors.
+ """
num_dim = 2
- class FakePotential(BasePotential):
- def __call__(self, theta, **kwargs):
- return torch.ones_like(torch.as_tensor(theta[:, 0], dtype=torch.float32))
+ prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
+ potential_fn = FakePotential(prior=prior)
+ theta_transform = torch_tf.identity_transform
- def allow_iid_x(self) -> bool:
- return True
+ posterior = VIPosterior(
+ potential_fn,
+ prior,
+ theta_transform=theta_transform,
+ q=q,
+ )
+ posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32)))
+
+ torch.save(posterior, "posterior.pt")
+ posterior_loaded = torch.load("posterior.pt")
+ assert (
+ posterior._x == posterior_loaded._x
+ ).all(), "Mhh, something with the pickled is strange"
+
+ # Try if they are the same
+ torch.manual_seed(0)
+ s1 = posterior._q.rsample()
+ torch.manual_seed(0)
+ s2 = posterior_loaded._q.rsample()
+
+ os.remove("posterior.pt")
+
+ assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange"
+
+
+
+def test_vi_posterior_inferface():
+ num_dim = 2
prior = MultivariateNormal(zeros(num_dim), eye(num_dim))
potential_fn = FakePotential(prior=prior)
From a2d72d9c3cd55b0e37c9cd1a61e0ee4d350c3b95 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 07:34:59 +0100
Subject: [PATCH 6/9] Forgot formatting tests
---
tests/vi_test.py | 21 +++++++++++----------
1 file changed, 11 insertions(+), 10 deletions(-)
diff --git a/tests/vi_test.py b/tests/vi_test.py
index 019435351..1785cc625 100644
--- a/tests/vi_test.py
+++ b/tests/vi_test.py
@@ -3,9 +3,9 @@
from __future__ import annotations
+import os
from copy import deepcopy
-import os
import numpy as np
import pytest
import torch
@@ -20,6 +20,7 @@
from sbi.utils import MultipleIndependent
from tests.test_utils import check_c2st
+
class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))
@@ -215,8 +216,8 @@ def test_deepcopy_support(q: str):
assert (
posterior._x == posterior_copy._x
).all(), "Mhh, something with the copy is strange"
-
- # Try if they are the same
+
+ # Try if they are the same
torch.manual_seed(0)
s1 = posterior._q.rsample()
torch.manual_seed(0)
@@ -227,10 +228,11 @@ def test_deepcopy_support(q: str):
posterior.q.rsample()
posterior_copy = deepcopy(posterior)
+
@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
def test_pickle_support(q: str):
"""Tests if the VIPosterior can be saved and loaded via pickle.
-
+
Args:
q: Different variational posteriors.
"""
@@ -247,23 +249,22 @@ def test_pickle_support(q: str):
q=q,
)
posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32)))
-
+
torch.save(posterior, "posterior.pt")
posterior_loaded = torch.load("posterior.pt")
assert (
posterior._x == posterior_loaded._x
).all(), "Mhh, something with the pickled is strange"
-
- # Try if they are the same
+
+ # Try if they are the same
torch.manual_seed(0)
s1 = posterior._q.rsample()
torch.manual_seed(0)
s2 = posterior_loaded._q.rsample()
-
+
os.remove("posterior.pt")
-
+
assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange"
-
def test_vi_posterior_inferface():
From 7ba753b9955d7f2a97eddd17900ee8ac35526788 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 21 Feb 2024 07:55:25 +0100
Subject: [PATCH 7/9] Global fake potential, fix shapes
---
tests/vi_test.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/vi_test.py b/tests/vi_test.py
index 1785cc625..0e00273fb 100644
--- a/tests/vi_test.py
+++ b/tests/vi_test.py
@@ -23,7 +23,7 @@
class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
- return torch.ones_like(torch.as_tensor(theta, dtype=torch.float32))
+ return torch.ones(theta.shape[0], dtype=torch.float32)
def allow_iid_x(self) -> bool:
return True
From 22fc7ccfc81b2b3a092a21ffe7c068f28fb7704b Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Thu, 22 Feb 2024 13:38:56 +0100
Subject: [PATCH 8/9] Restore attributes on reloading
---
sbi/inference/posteriors/vi_posterior.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index 6c6010440..ea8e95dc9 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -207,7 +207,7 @@ def set_q(
modules: List of modules associated with the distribution object.
"""
- self._q_arg = q
+ self._q_arg = (q, parameters, modules)
if isinstance(q, Distribution):
q = adapt_variational_distribution(
q,
@@ -591,6 +591,9 @@ def __getstate__(self):
def __setstate__(self, state_dict: Dict):
self.__dict__ = state_dict
- # Restore deepcopy compatibility
+ q = deepcopy(self._q)
+ # Restore removed attributes
+ self.set_q(*self._q_arg)
+ self._q = q
make_object_deepcopy_compatible(self)
make_object_deepcopy_compatible(self.q)
From ccac2bc0ba8db94e94b53ce23f18c7a1d633e895 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Thu, 22 Feb 2024 15:52:33 +0100
Subject: [PATCH 9/9] Improve documentation and Assertion errors
---
sbi/inference/posteriors/vi_posterior.py | 31 ++++++++++++++++++++++--
tests/vi_test.py | 10 +++++---
2 files changed, 36 insertions(+), 5 deletions(-)
diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index ea8e95dc9..6d2748556 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -569,6 +569,18 @@ def map(
)
def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior":
+ """This method is called when using `copy.deepcopy` on the object.
+
+ It defines how the object is copied. We need to overwrite this method, since
+ the default implementation does use __getstate__ and __setstate__ which we
+ overwrite to enable pickling (and in particular the necessary modifications are incompatible with deep copying).
+
+ Args:
+ memo (Optional[Dict], optional): Deep copy internal memo. Defaults to None.
+
+ Returns:
+ VIPosterior: Deep copy of the VIPosterior.
+ """
if memo is None:
memo = {}
# Create a new instance of the class
@@ -581,8 +593,15 @@ def __deepcopy__(self, memo: Optional[Dict] = None) -> "VIPosterior":
setattr(result, k, copy.deepcopy(v, memo))
return result
- def __getstate__(self):
- """This method is called when pickling the object."""
+ def __getstate__(self) -> Dict:
+ """This method is called when pickling the object.
+
+ It defines what is pickled. We need to overwrite this method,
+ since some parts do not support pickle protocols (e.g. due to local functions, etc.).
+
+ Returns:
+ Dict: All attributes of the VIPosterior.
+ """
self._optimizer = None
self.__deepcopy__ = None # type: ignore
self._q_build_fn = None
@@ -590,6 +609,14 @@ def __getstate__(self):
return self.__dict__
def __setstate__(self, state_dict: Dict):
+ """This method is called when unpickling the object.
+
+ In particular, we need to restore the removed attributes and ensure that the
+ object e.g. remains deep copy compatible.
+
+ Args:
+ state_dict (Dict): Given state dictionary, we will restore the object from it.
+ """
self.__dict__ = state_dict
q = deepcopy(self._q)
# Restore removed attributes
diff --git a/tests/vi_test.py b/tests/vi_test.py
index 0e00273fb..06ef4ec71 100644
--- a/tests/vi_test.py
+++ b/tests/vi_test.py
@@ -211,18 +211,22 @@ def test_deepcopy_support(q: str):
)
posterior_copy = deepcopy(posterior)
posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32)))
- assert posterior._x != posterior_copy._x, "Mhh, something with the copy is strange"
+ assert (
+ posterior._x != posterior_copy._x
+ ), "Default x attributed of original and copied but modified VIPosterior must be the different, on change (otherwise it is not a deep copy)."
posterior_copy = deepcopy(posterior)
assert (
posterior._x == posterior_copy._x
- ).all(), "Mhh, something with the copy is strange"
+ ).all(), "Default x attributed of original and copied VIPosterior must be the same."
# Try if they are the same
torch.manual_seed(0)
s1 = posterior._q.rsample()
torch.manual_seed(0)
s2 = posterior_copy._q.rsample()
- assert torch.allclose(s1, s2), "Mhh, something with the pickled is strange"
+ assert torch.allclose(
+ s1, s2
+ ), "Samples from original and unpickled VIPosterior must be close."
# Produces nonleaf tensors in the cache... -> Can lead to failure of deepcopy.
posterior.q.rsample()