🏗️ Refactor flows module
* Split the flows module into many sub-modules
* Spread the bisection bounds of monotonic transformations
* New rotation transformation (RotationTransform)
* New gaussianization transformation (GaussianizationTransform)
* New gaussianization flow (GF)
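The new components can be exercised directly from the public modules; the sketch below is a minimal illustration, with class names and call signatures assumed from the updated tests in this commit (`GF(features, context)`, `RotationTransform(matrix)`, `GaussianizationTransform(shifts, scales)`) rather than from documented API.

```python
# Minimal sketch of the new components; signatures assumed from the updated tests.
import torch
from zuko.flows import GF
from zuko.transforms import GaussianizationTransform, RotationTransform

# Gaussianization flow over 3 features, conditioned on a 5-dimensional context.
flow = GF(3, 5)
x, c = torch.randn(256, 3), torch.randn(5)
log_p = flow(c).log_prob(x)  # conditional log-density, shape (256,)

# Rotation transform over 5 features, parametrized by an arbitrary square matrix.
rot = RotationTransform(torch.randn(5, 5))
y = rot(torch.randn(256, 5))

# Element-wise gaussianization transform with per-element mixture parameters.
gauss = GaussianizationTransform(torch.randn(256, 8), torch.randn(256, 8))
z = gauss(torch.randn(256))
```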
francois-rozet committed Aug 1, 2023
1 parent cb67fec commit 1d49e3f
Showing 15 changed files with 1,537 additions and 1,288 deletions.
README.md (7 changes: 5 additions & 2 deletions)
@@ -2,9 +2,9 @@

# Zuko - Normalizing flows in PyTorch

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the `Distribution` and `Transform` classes of `torch` are not sub-classes of `torch.nn.Module`, which means you cannot send their internal tensors to GPU with `.to('cuda')` or retrieve their parameters with `.parameters()`.
Zuko is a Python package that implements normalizing flows in [PyTorch](https://pytorch.org). It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the `Distribution` and `Transform` classes of `torch` are not sub-classes of `torch.nn.Module`, which means you cannot send their internal tensors to GPU with `.to('cuda')` or retrieve their parameters with `.parameters()`. Worse, the concepts of conditional distribution and transformation, which are essential for probabilistic inference, are impossible to express.

To solve this problem, `zuko` defines two abstract classes: `DistributionModule` and `TransformModule`. The former is any `Module` whose forward pass returns a `Distribution` and the latter is any `Module` whose forward pass returns a `Transform`. A normalizing flow is just a `DistributionModule` which contains a list of `TransformModule` and a base `DistributionModule`. This design allows for flows that behave like distributions while retaining the benefits of `Module`. It also makes the implementations easier to understand and extend.
To solve these problems, `zuko` defines two concepts, `DistributionModule` and `TransformModule`, which represent recipes for building distributions and transformations, respectively. To condition a distribution or transformation simply means to consider the condition/context as part of the recipe, similar to [Pyro](http://pyro.ai)'s `ConditionalTransformModule`. A normalizing flow is a special `DistributionModule` that contains a sequence of `TransformModule` and a base `DistributionModule`. This design allows for flows that behave like distributions while retaining the benefits of `Module`. It also makes the implementations easier to understand and extend.
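To make the recipe pattern concrete, here is a small usage sketch, assuming the `MAF` constructor and the call pattern exercised by the test suite in this commit:

```python
# Usage sketch; constructor arguments and shapes mirror the tests in this commit.
import torch
from zuko.flows import MAF

flow = MAF(3, 5)               # flow over 3 features with a 5-dimensional context

x = torch.randn(256, 3)        # points to evaluate
c = torch.randn(5)             # conditioning context

dist = flow(c)                 # the forward pass builds a torch Distribution
log_p = dist.log_prob(x)       # shape (256,), differentiable w.r.t. the flow parameters
samples = dist.sample((32,))   # shape (32, 3)

# Because the flow is a torch.nn.Module, the usual machinery applies.
flow.to('cpu')
n_params = sum(p.numel() for p in flow.parameters())
```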

> In the [Avatar](https://wikipedia.org/wiki/Avatar:_The_Last_Airbender) cartoon, [Zuko](https://wikipedia.org/wiki/Zuko) is a powerful firebender 🔥
@@ -76,13 +76,16 @@ For more information, check out the documentation at [zuko.readthedocs.io](https

| Class | Year | Reference |
|:-------:|:----:|-----------|
| `GMM` | - | [Gaussian Mixture Model](https://wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model) |
| `NICE` | 2014 | [Non-linear Independent Components Estimation](https://arxiv.org/abs/1410.8516) |
| `MAF` | 2017 | [Masked Autoregressive Flow for Density Estimation](https://arxiv.org/abs/1705.07057) |
| `NSF` | 2019 | [Neural Spline Flows](https://arxiv.org/abs/1906.04032) |
| `NCSF` | 2020 | [Normalizing Flows on Tori and Spheres](https://arxiv.org/abs/2002.02428) |
| `SOSPF` | 2019 | [Sum-of-Squares Polynomial Flow](https://arxiv.org/abs/1905.02325) |
| `NAF` | 2018 | [Neural Autoregressive Flows](https://arxiv.org/abs/1804.00779) |
| `UNAF` | 2019 | [Unconstrained Monotonic Neural Networks](https://arxiv.org/abs/1908.05164) |
| `CNF` | 2018 | [Neural Ordinary Differential Equations](https://arxiv.org/abs/1806.07366) |
| `GF` | 2020 | [Gaussianization Flows](https://arxiv.org/abs/2003.01941) |

## Contributing

docs/index.rst (4 changes: 2 additions & 2 deletions)
@@ -7,9 +7,9 @@
Zuko
====

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the `Distribution` and `Transform` classes of :mod:`torch` are not sub-classes of :class:`torch.nn.Module`, which means you cannot send their internal tensors to GPU with :py:`.to('cuda')` or retrieve their parameters with :py:`.parameters()`.
Zuko is a Python package that implements normalizing flows in `PyTorch <https://pytorch.org>`_. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the `Distribution` and `Transform` classes of :mod:`torch` are not sub-classes of :class:`torch.nn.Module`, which means you cannot send their internal tensors to GPU with :py:`.to('cuda')` or retrieve their parameters with :py:`.parameters()`. Worse, the concepts of conditional distribution and transformation, which are essential for probabilistic inference, are impossible to express.

To solve this problem, :mod:`zuko` defines two abstract classes: :class:`zuko.flows.DistributionModule` and :class:`zuko.flows.TransformModule`. The former is any `Module` whose forward pass returns a `Distribution` and the latter is any `Module` whose forward pass returns a `Transform`. A normalizing flow is just a `DistributionModule` which contains a list of `TransformModule` and a base `DistributionModule`. This design allows for flows that behave like distributions while retaining the benefits of `Module`. It also makes the implementations easier to understand and extend.
To solve these problems, :mod:`zuko` defines two concepts, :class:`zuko.flows.core.DistributionModule` and :class:`zuko.flows.core.TransformModule`, which represent recipes for building distributions and transformations, respectively. To condition a distribution or transformation simply means to consider the condition/context as part of the recipe, similar to `Pyro <http://pyro.ai>`_'s `ConditionalTransformModule`. A normalizing flow is a special `DistributionModule` that contains a sequence of `TransformModule` and a base `DistributionModule`. This design allows for flows that behave like distributions while retaining the benefits of `Module`. It also makes the implementations easier to understand and extend.

Installation
------------
tests/test_flows.py (56 changes: 31 additions & 25 deletions)
@@ -8,37 +8,43 @@
from zuko.flows import *


torch.set_default_dtype(torch.float64)


def test_flows(tmp_path):
flows = [
GMM(3, 5),
MAF(3, 5),
NSF(3, 5),
SOSPF(3, 5),
NAF(3, 5),
UNAF(3, 5),
GCF(3, 5),
CNF(3, 5),
Fs = [
GMM,
NICE,
MAF,
NSF,
SOSPF,
NAF,
UNAF,
CNF,
GF,
]

for flow in flows:
for F in Fs:
flow = F(3, 5)

# Evaluation of log_prob
x, c = randn(256, 3), randn(5)
log_p = flow(c).log_prob(x)

assert log_p.shape == (256,), flow
assert log_p.requires_grad, flow
assert log_p.shape == (256,), F
assert log_p.requires_grad, F

flow.zero_grad(set_to_none=True)
loss = -log_p.mean()
loss.backward()

for p in flow.parameters():
assert p.grad is not None, flow
assert p.grad is not None, F

# Sampling
x = flow(c).sample((32,))

assert x.shape == (32, 3), flow
assert x.shape == (32, 3), F

# Reparameterization trick
if flow(c).has_rsample:
@@ -49,15 +55,15 @@ def test_flows(tmp_path):
loss.backward()

for p in flow.parameters():
assert p.grad is not None, flow
assert p.grad is not None, F

# Invertibility
if isinstance(flow, FlowModule):
x, c = randn(256, 3), randn(256, 5)
t = flow(c).transform
z = t.inv(t(x))

assert torch.allclose(x, z, atol=1e-4), flow
assert torch.allclose(x, z, atol=1e-4), F

# Saving
torch.save(flow, tmp_path / 'flow.pth')
@@ -72,21 +78,21 @@ def test_flows(tmp_path):
torch.manual_seed(seed)
log_p_bis = flow_bis(c).log_prob(x)

assert torch.allclose(log_p, log_p_bis), flow
assert torch.allclose(log_p, log_p_bis), F

# Printing
assert repr(flow), flow
assert repr(flow), F


def test_triangular_transforms():
Ts = [
ElementWiseTransform,
GeneralCouplingTransform,
MaskedAutoregressiveTransform,
partial(MaskedAutoregressiveTransform, passes=2),
NeuralAutoregressiveTransform,
partial(NeuralAutoregressiveTransform, passes=2),
UnconstrainedNeuralAutoregressiveTransform,
GeneralCouplingTransform,
]

for T in Ts:
@@ -97,16 +103,16 @@ def test_triangular_transforms():

assert y.shape == x.shape, t
assert y.requires_grad, t
assert torch.allclose(t().inv(y), x, atol=1e-4), t
assert torch.allclose(t().inv(y), x, atol=1e-4), T

# With context
t = T(3, 5)
x, c = randn(256, 3), randn(5)
y = t(c)(x)

assert y.shape == x.shape, t
assert y.requires_grad, t
assert torch.allclose(t(c).inv(y), x, atol=1e-4), t
assert y.shape == x.shape, T
assert y.requires_grad, T
assert torch.allclose(t(c).inv(y), x, atol=1e-4), T

# Jacobian
t = T(7)
@@ -116,5 +122,5 @@ def test_triangular_transforms():
J = torch.autograd.functional.jacobian(t(), x)
ladj = torch.linalg.slogdet(J).logabsdet

assert torch.allclose(t().log_abs_det_jacobian(x, y), ladj, atol=1e-4), t
assert torch.allclose(J.diag().abs().log().sum(), ladj, atol=1e-4), t
assert torch.allclose(t().log_abs_det_jacobian(x, y), ladj, atol=1e-4), T
assert torch.allclose(J.diag().abs().log().sum(), ladj, atol=1e-4), T
tests/test_transforms.py (8 changes: 6 additions & 2 deletions)
@@ -8,6 +8,9 @@
from zuko.transforms import *


torch.set_default_dtype(torch.float64)


def test_univariate_transforms():
ts = [
IdentityTransform(),
@@ -18,6 +21,7 @@ def test_univariate_transforms():
MonotonicAffineTransform(randn(256), randn(256)),
MonotonicRQSTransform(randn(256, 8), randn(256, 8), randn(256, 7)),
MonotonicTransform(lambda x: x**3),
GaussianizationTransform(randn(256, 8), randn(256, 8)),
UnconstrainedMonotonicTransform(lambda x: torch.exp(-x**2) + 1e-2, randn(256)),
SOSPolynomialTransform(randn(256, 2, 4), randn(256)),
]
@@ -27,7 +31,7 @@ def test_univariate_transforms():
if hasattr(t.domain, 'lower_bound'):
x = torch.linspace(t.domain.lower_bound + 1e-2, t.domain.upper_bound - 1e-2, 256)
else:
x = torch.linspace(-4.999, 4.999, 256)
x = torch.linspace(-5.0, 5.0, 256)

y = t(x)

@@ -71,8 +75,8 @@ def test_multivariate_transforms():

ts = [
FreeFormJacobianTransform(f, 0.0, 1.0),
LULinearTransform(randn(5, 5)),
PermutationTransform(torch.randperm(5)),
RotationTransform(randn(5, 5)),
]

for t in ts: