diff --git a/README.md b/README.md
index e3fc9a7..1a27a9a 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,9 @@ # Zuko - Normalizing flows in PyTorch
-Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the `Distribution` and `Transform` classes of `torch` are not sub-classes of `torch.nn.Module`, which means you cannot send their internal tensors to GPU with `.to('cuda')` or retrieve their parameters with `.parameters()`.
+Zuko is a Python package that implements normalizing flows in [PyTorch](https://pytorch.org). It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the `Distribution` and `Transform` classes of `torch` are not sub-classes of `torch.nn.Module`, which means you cannot send their internal tensors to GPU with `.to('cuda')` or retrieve their parameters with `.parameters()`. Worse, the concepts of conditional distribution and transformation, which are essential for probabilistic inference, are impossible to express.
-To solve this problem, `zuko` defines two abstract classes: `DistributionModule` and `TransformModule`. The former is any `Module` whose forward pass returns a `Distribution` and the latter is any `Module` whose forward pass returns a `Transform`. A normalizing flow is just a `DistributionModule` which contains a list of `TransformModule` and a base `DistributionModule`. This design allows for flows that behave like distributions while retaining the benefits of `Module`. It also makes the implementations easier to understand and extend.
+To solve these problems, `zuko` defines two concepts, `DistributionModule` and `TransformModule`, which represent recipes for building distributions and transformations, respectively. Conditioning a distribution or transformation simply means treating the condition, or context, as part of the recipe, similar to [Pyro](http://pyro.ai)'s `ConditionalTransformModule`. A normalizing flow is a special `DistributionModule` that contains a sequence of `TransformModule` and a base `DistributionModule`. This design allows for flows that behave like distributions while retaining the benefits of `Module`. It also makes the implementations easier to understand and extend.
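A minimal sketch of what this design buys in practice, assuming only the high-level API exercised by the test suite below (`MAF(features, context)`, `flow(c).log_prob(x)`, `flow(c).sample(...)`); the data tensors are random placeholders:

```python
import torch

from zuko.flows import MAF

# The flow is a regular nn.Module: its parameters are registered and the whole
# model can be moved to a device at once, e.g. flow.to('cuda').
flow = MAF(features=3, context=5)
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

x, c = torch.randn(256, 3), torch.randn(5)  # placeholder observations and context

# Calling the flow with a context c returns a torch Distribution p(X | c).
loss = -flow(c).log_prob(x).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()

# Sampling from the conditional distribution.
samples = flow(c).sample((32,))  # shape (32, 3)
```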
> In the [Avatar](https://wikipedia.org/wiki/Avatar:_The_Last_Airbender) cartoon, [Zuko](https://wikipedia.org/wiki/Zuko) is a powerful firebender 🔥
@@ -76,6 +76,8 @@ For more information, check out the documentation at [zuko.readthedocs.io](https
 | Class | Year | Reference |
 |:-------:|:----:|-----------|
+| `GMM` | - | [Gaussian Mixture Model](https://wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model) |
+| `NICE` | 2014 | [Non-linear Independent Components Estimation](https://arxiv.org/abs/1410.8516) |
 | `MAF` | 2017 | [Masked Autoregressive Flow for Density Estimation](https://arxiv.org/abs/1705.07057) |
 | `NSF` | 2019 | [Neural Spline Flows](https://arxiv.org/abs/1906.04032) |
 | `NCSF` | 2020 | [Normalizing Flows on Tori and Spheres](https://arxiv.org/abs/2002.02428) |
@@ -83,6 +85,7 @@ For more information, check out the documentation at [zuko.readthedocs.io](https
 | `NAF` | 2018 | [Neural Autoregressive Flows](https://arxiv.org/abs/1804.00779) |
 | `UNAF` | 2019 | [Unconstrained Monotonic Neural Networks](https://arxiv.org/abs/1908.05164) |
 | `CNF` | 2018 | [Neural Ordinary Differential Equations](https://arxiv.org/abs/1806.07366) |
+| `GF` | 2020 | [Gaussianization Flows](https://arxiv.org/abs/2003.01941) |
 ## Contributing
diff --git a/docs/index.rst b/docs/index.rst
index a218ecc..d74066a 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -7,9 +7,9 @@ Zuko
 ====
-Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the `Distribution` and `Transform` classes of :mod:`torch` are not sub-classes of :class:`torch.nn.Module`, which means you cannot send their internal tensors to GPU with :py:`.to('cuda')` or retrieve their parameters with :py:`.parameters()`.
+Zuko is a Python package that implements normalizing flows in `PyTorch <https://pytorch.org>`_. It relies as much as possible on distributions and transformations already provided by PyTorch. Unfortunately, the `Distribution` and `Transform` classes of :mod:`torch` are not sub-classes of :class:`torch.nn.Module`, which means you cannot send their internal tensors to GPU with :py:`.to('cuda')` or retrieve their parameters with :py:`.parameters()`. Worse, the concepts of conditional distribution and transformation, which are essential for probabilistic inference, are impossible to express.
-To solve this problem, :mod:`zuko` defines two abstract classes: :class:`zuko.flows.DistributionModule` and :class:`zuko.flows.TransformModule`. The former is any `Module` whose forward pass returns a `Distribution` and the latter is any `Module` whose forward pass returns a `Transform`. A normalizing flow is just a `DistributionModule` which contains a list of `TransformModule` and a base `DistributionModule`. This design allows for flows that behave like distributions while retaining the benefits of `Module`. It also makes the implementations easier to understand and extend.
+To solve these problems, :mod:`zuko` defines two concepts, :class:`zuko.flows.core.DistributionModule` and :class:`zuko.flows.core.TransformModule`, which represent recipes for building distributions and transformations, respectively. Conditioning a distribution or transformation simply means treating the condition, or context, as part of the recipe, similar to `Pyro <http://pyro.ai>`_'s `ConditionalTransformModule`. A normalizing flow is a special `DistributionModule` that contains a sequence of `TransformModule` and a base `DistributionModule`.
This design allows for flows that behave like distributions while retaining the benefits of `Module`. It also makes the implementations easier to understand and extend. Installation ------------ diff --git a/tests/test_flows.py b/tests/test_flows.py index b39dc59..7af29a2 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -8,37 +8,43 @@ from zuko.flows import * +torch.set_default_dtype(torch.float64) + + def test_flows(tmp_path): - flows = [ - GMM(3, 5), - MAF(3, 5), - NSF(3, 5), - SOSPF(3, 5), - NAF(3, 5), - UNAF(3, 5), - GCF(3, 5), - CNF(3, 5), + Fs = [ + GMM, + NICE, + MAF, + NSF, + SOSPF, + NAF, + UNAF, + CNF, + GF, ] - for flow in flows: + for F in Fs: + flow = F(3, 5) + # Evaluation of log_prob x, c = randn(256, 3), randn(5) log_p = flow(c).log_prob(x) - assert log_p.shape == (256,), flow - assert log_p.requires_grad, flow + assert log_p.shape == (256,), F + assert log_p.requires_grad, F flow.zero_grad(set_to_none=True) loss = -log_p.mean() loss.backward() for p in flow.parameters(): - assert p.grad is not None, flow + assert p.grad is not None, F # Sampling x = flow(c).sample((32,)) - assert x.shape == (32, 3), flow + assert x.shape == (32, 3), F # Reparameterization trick if flow(c).has_rsample: @@ -49,7 +55,7 @@ def test_flows(tmp_path): loss.backward() for p in flow.parameters(): - assert p.grad is not None, flow + assert p.grad is not None, F # Invertibility if isinstance(flow, FlowModule): @@ -57,7 +63,7 @@ def test_flows(tmp_path): t = flow(c).transform z = t.inv(t(x)) - assert torch.allclose(x, z, atol=1e-4), flow + assert torch.allclose(x, z, atol=1e-4), F # Saving torch.save(flow, tmp_path / 'flow.pth') @@ -72,21 +78,21 @@ def test_flows(tmp_path): torch.manual_seed(seed) log_p_bis = flow_bis(c).log_prob(x) - assert torch.allclose(log_p, log_p_bis), flow + assert torch.allclose(log_p, log_p_bis), F # Printing - assert repr(flow), flow + assert repr(flow), F def test_triangular_transforms(): Ts = [ ElementWiseTransform, + GeneralCouplingTransform, MaskedAutoregressiveTransform, partial(MaskedAutoregressiveTransform, passes=2), NeuralAutoregressiveTransform, partial(NeuralAutoregressiveTransform, passes=2), UnconstrainedNeuralAutoregressiveTransform, - GeneralCouplingTransform, ] for T in Ts: @@ -97,16 +103,16 @@ def test_triangular_transforms(): assert y.shape == x.shape, t assert y.requires_grad, t - assert torch.allclose(t().inv(y), x, atol=1e-4), t + assert torch.allclose(t().inv(y), x, atol=1e-4), T # With context t = T(3, 5) x, c = randn(256, 3), randn(5) y = t(c)(x) - assert y.shape == x.shape, t - assert y.requires_grad, t - assert torch.allclose(t(c).inv(y), x, atol=1e-4), t + assert y.shape == x.shape, T + assert y.requires_grad, T + assert torch.allclose(t(c).inv(y), x, atol=1e-4), T # Jacobian t = T(7) @@ -116,5 +122,5 @@ def test_triangular_transforms(): J = torch.autograd.functional.jacobian(t(), x) ladj = torch.linalg.slogdet(J).logabsdet - assert torch.allclose(t().log_abs_det_jacobian(x, y), ladj, atol=1e-4), t - assert torch.allclose(J.diag().abs().log().sum(), ladj, atol=1e-4), t + assert torch.allclose(t().log_abs_det_jacobian(x, y), ladj, atol=1e-4), T + assert torch.allclose(J.diag().abs().log().sum(), ladj, atol=1e-4), T diff --git a/tests/test_transforms.py b/tests/test_transforms.py index b6a4db0..1164a52 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -8,6 +8,9 @@ from zuko.transforms import * +torch.set_default_dtype(torch.float64) + + def test_univariate_transforms(): ts = [ IdentityTransform(), @@ -18,6 +21,7 @@ 
def test_univariate_transforms(): MonotonicAffineTransform(randn(256), randn(256)), MonotonicRQSTransform(randn(256, 8), randn(256, 8), randn(256, 7)), MonotonicTransform(lambda x: x**3), + GaussianizationTransform(randn(256, 8), randn(256, 8)), UnconstrainedMonotonicTransform(lambda x: torch.exp(-x**2) + 1e-2, randn(256)), SOSPolynomialTransform(randn(256, 2, 4), randn(256)), ] @@ -27,7 +31,7 @@ def test_univariate_transforms(): if hasattr(t.domain, 'lower_bound'): x = torch.linspace(t.domain.lower_bound + 1e-2, t.domain.upper_bound - 1e-2, 256) else: - x = torch.linspace(-4.999, 4.999, 256) + x = torch.linspace(-5.0, 5.0, 256) y = t(x) @@ -71,8 +75,8 @@ def test_multivariate_transforms(): ts = [ FreeFormJacobianTransform(f, 0.0, 1.0), - LULinearTransform(randn(5, 5)), PermutationTransform(torch.randperm(5)), + RotationTransform(randn(5, 5)), ] for t in ts: diff --git a/zuko/flows.py b/zuko/flows.py deleted file mode 100644 index 65cdf29..0000000 --- a/zuko/flows.py +++ /dev/null @@ -1,1126 +0,0 @@ -r"""Parameterized flows and transformations.""" - -__all__ = [ - 'DistributionModule', - 'TransformModule', - 'FlowModule', - 'GMM', - 'ElementWiseTransform', - 'MaskedAutoregressiveTransform', - 'MAF', - 'NSF', - 'NCSF', - 'SOSPF', - 'NeuralAutoregressiveTransform', - 'NAF', - 'UnconstrainedNeuralAutoregressiveTransform', - 'UNAF', - 'GeneralCouplingTransform', - 'GCF', - 'FFJTransform', - 'CNF', -] - -import abc -import torch -import torch.nn as nn - -from functools import partial -from math import ceil, pi, prod -from textwrap import indent -from torch import Tensor, BoolTensor, LongTensor, Size -from torch.distributions import * -from typing import * - -from .distributions import * -from .transforms import * -from .nn import * -from .utils import broadcast, unpack - - -class DistributionModule(nn.Module, abc.ABC): - r"""Abstract distribution module.""" - - @abc.abstractmethod - def forward(c: Tensor = None) -> Distribution: - r""" - Arguments: - c: A context :math:`c`. - - Returns: - A distribution :math:`p(X | c)`. - """ - - pass - - -class TransformModule(nn.Module, abc.ABC): - r"""Abstract transformation module.""" - - @abc.abstractmethod - def forward(c: Tensor = None) -> Transform: - r""" - Arguments: - c: A context :math:`c`. - - Returns: - A transformation :math:`y = f(x | c)`. - """ - - pass - - -class FlowModule(DistributionModule): - r"""Creates a normalizing flow module. - - Arguments: - transforms: A list of transformation modules. - base: A distribution module. - """ - - def __init__( - self, - transforms: Sequence[TransformModule], - base: DistributionModule, - ): - super().__init__() - - self.transforms = nn.ModuleList(transforms) - self.base = base - - def forward(self, c: Tensor = None) -> NormalizingFlow: - r""" - Arguments: - c: A context :math:`c`. - - Returns: - A normalizing flow :math:`p(X | c)`. - """ - - transform = ComposedTransform(*(t(c) for t in self.transforms)) - - if c is None: - base = self.base(c) - else: - base = self.base(c).expand(c.shape[:-1]) - - return NormalizingFlow(transform, base) - - -class Unconditional(nn.Module): - r"""Creates a module that registers the positional arguments of a function. - The function is evaluated during the forward pass and the result is returned. - - Arguments: - meta: An arbitrary function. - args: The positional tensor arguments passed to `meta`. - buffer: Whether tensors are registered as buffer or parameter. - kwargs: The keyword arguments passed to `meta`. 
- """ - - def __init__( - self, - meta: Callable[..., Any], - *args: Tensor, - buffer: bool = False, - **kwargs, - ): - super().__init__() - - self.meta = meta - - for i, arg in enumerate(args): - if buffer: - self.register_buffer(f'_{i}', arg) - else: - self.register_parameter(f'_{i}', nn.Parameter(arg)) - - self.kwargs = kwargs - - def __repr__(self) -> str: - return repr(self.forward()) - - def forward(self, c: Tensor = None) -> Any: - return self.meta( - *self._parameters.values(), - *self._buffers.values(), - **self.kwargs, - ) - - -class Parameters(nn.ParameterList): - r"""Creates a list of parameters.""" - - def extra_repr(self) -> str: - lines = [ - f'({i}): Tensor(shape={tuple(p.shape)})' - for i, p in enumerate(self) - ] - - return indent('\n'.join(lines), ' ') - - -class GMM(DistributionModule): - r"""Creates a Gaussian mixture model (GMM). - - .. math:: p(X | c) = \sum_{i = 1}^K w_i(c) \, \mathcal{N}(X | \mu_i(c), \Sigma_i(c)) - - Arguments: - features: The number of features. - context: The number of context features. - components: The number of components :math:`K` in the mixture. - kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. - """ - - def __init__( - self, - features: int, - context: int = 0, - components: int = 2, - **kwargs, - ): - super().__init__() - - shapes = [ - (components,), # probabilities - (components, features), # mean - (components, features), # diagonal - (components, features * (features - 1) // 2), # off diagonal - ] - - self.shapes = shapes - self.total = sum(prod(s) for s in shapes) - - if context > 0: - self.hyper = MLP(context, self.total, **kwargs) - else: - self.phi = Parameters(torch.randn(*s) for s in shapes) - - def forward(self, c: Tensor = None) -> Distribution: - if c is None: - phi = self.phi - else: - phi = self.hyper(c) - phi = unpack(phi, self.shapes) - - logits, loc, diag, tril = phi - - scale = torch.diag_embed(diag.exp() + 1e-5) - mask = torch.tril(torch.ones_like(scale, dtype=bool), diagonal=-1) - scale = torch.masked_scatter(scale, mask, tril) - - return Mixture(MultivariateNormal(loc=loc, scale_tril=scale), logits) - - -class ElementWiseTransform(TransformModule): - r"""Creates an element-wise transformation. - - Arguments: - features: The number of features. - context: The number of context features. - univariate: The univariate transformation constructor. - shapes: The shapes of the univariate transformation parameters. - kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. 
- - Example: - >>> t = ElementWiseTransform(3, 4) - >>> t - ElementWiseTransform( - (base): MonotonicAffineTransform() - (hyper): MLP( - (0): Linear(in_features=4, out_features=64, bias=True) - (1): ReLU() - (2): Linear(in_features=64, out_features=64, bias=True) - (3): ReLU() - (4): Linear(in_features=64, out_features=6, bias=True) - ) - ) - >>> x = torch.randn(3) - >>> x - tensor([2.1983, -1.3182, 0.0329]) - >>> c = torch.randn(4) - >>> y = t(c)(x) - >>> t(c).inv(y) - tensor([2.1983, -1.3182, 0.0329]) - """ - - def __init__( - self, - features: int, - context: int = 0, - univariate: Callable[..., Transform] = MonotonicAffineTransform, - shapes: Sequence[Size] = ((), ()), - **kwargs, - ): - super().__init__() - - self.univariate = univariate - self.shapes = shapes - self.total = sum(prod(s) for s in shapes) - - if context > 0: - self.hyper = MLP(context, features * self.total, **kwargs) - else: - self.phi = Parameters(torch.randn(features, *s) for s in shapes) - - def extra_repr(self) -> str: - base = self.univariate(*map(torch.randn, self.shapes)) - - return '\n'.join([ - f'(base): {base}', - ]) - - def forward(self, c: Tensor = None) -> Transform: - if c is None: - phi = self.phi - else: - phi = self.hyper(c) - phi = phi.unflatten(-1, (-1, self.total)) - phi = unpack(phi, self.shapes) - - return DependentTransform(self.univariate(*phi), 1) - - -class MaskedAutoregressiveTransform(TransformModule): - r"""Creates a masked autoregressive transformation. - - References: - | Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017) - | https://arxiv.org/abs/1705.07057 - - Arguments: - features: The number of features. - context: The number of context features. - passes: The number of sequential passes for the inverse transformation. If - :py:`None`, use the number of features instead, making the transformation - fully autoregressive. Coupling corresponds to :py:`passes=2`. - order: The feature ordering. If :py:`None`, use :py:`range(features)` instead. - univariate: The univariate transformation constructor. - shapes: The shapes of the univariate transformation parameters. - kwargs: Keyword arguments passed to :class:`zuko.nn.MaskedMLP`. 
- - Example: - >>> t = MaskedAutoregressiveTransform(3, 4) - >>> t - MaskedAutoregressiveTransform( - (base): MonotonicAffineTransform() - (order): [0, 1, 2] - (hyper): MaskedMLP( - (0): MaskedLinear(in_features=7, out_features=64, bias=True) - (1): ReLU() - (2): MaskedLinear(in_features=64, out_features=64, bias=True) - (3): ReLU() - (4): MaskedLinear(in_features=64, out_features=6, bias=True) - ) - ) - >>> x = torch.randn(3) - >>> x - tensor([-0.9485, 1.5290, 0.2018]) - >>> c = torch.randn(4) - >>> y = t(c)(x) - >>> t(c).inv(y) - tensor([-0.9485, 1.5290, 0.2018]) - """ - - def __init__( - self, - features: int, - context: int = 0, - passes: int = None, - order: LongTensor = None, - univariate: Callable[..., Transform] = MonotonicAffineTransform, - shapes: Sequence[Size] = ((), ()), - **kwargs, - ): - super().__init__() - - # Univariate transformation - self.univariate = univariate - self.shapes = shapes - self.total = sum(prod(s) for s in shapes) - - # Adjacency - self.register_buffer('order', None) - - if passes is None: - passes = features - - if order is None: - order = torch.arange(features) - else: - order = torch.as_tensor(order) - - self.passes = min(max(passes, 1), features) - self.order = torch.div(order, ceil(features / self.passes), rounding_mode='floor') - - in_order = torch.cat((self.order, torch.full((context,), -1))) - out_order = torch.repeat_interleave(self.order, self.total) - adjacency = out_order[:, None] > in_order - - # Hyper network - self.hyper = MaskedMLP(adjacency, **kwargs) - - def extra_repr(self) -> str: - base = self.univariate(*map(torch.randn, self.shapes)) - order = self.order.tolist() - - if len(order) > 10: - order = order[:5] + [...] + order[-5:] - order = str(order).replace('Ellipsis', '...') - - return '\n'.join([ - f'(base): {base}', - f'(order): {order}', - ]) - - def meta(self, c: Tensor, x: Tensor) -> Transform: - if c is not None: - x = torch.cat(broadcast(x, c, ignore=1), dim=-1) - - phi = self.hyper(x) - phi = phi.unflatten(-1, (-1, self.total)) - phi = unpack(phi, self.shapes) - - return DependentTransform(self.univariate(*phi), 1) - - def forward(self, c: Tensor = None) -> Transform: - return AutoregressiveTransform(partial(self.meta, c), self.passes) - - -class MAF(FlowModule): - r"""Creates a masked autoregressive flow (MAF). - - References: - | Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017) - | https://arxiv.org/abs/1705.07057 - - Arguments: - features: The number of features. - context: The number of context features. - transforms: The number of autoregressive transformations. - randperm: Whether features are randomly permuted between transformations or not. - If :py:`False`, features are in ascending (descending) order for even - (odd) transformations. - kwargs: Keyword arguments passed to :class:`MaskedAutoregressiveTransform`. 
- - Example: - >>> flow = MAF(3, 4, transforms=3) - >>> flow - MAF( - (transforms): ModuleList( - (0): MaskedAutoregressiveTransform( - (base): MonotonicAffineTransform() - (order): [0, 1, 2] - (hyper): MaskedMLP( - (0): MaskedLinear(in_features=7, out_features=64, bias=True) - (1): ReLU() - (2): MaskedLinear(in_features=64, out_features=64, bias=True) - (3): ReLU() - (4): MaskedLinear(in_features=64, out_features=6, bias=True) - ) - ) - (1): MaskedAutoregressiveTransform( - (base): MonotonicAffineTransform() - (order): [2, 1, 0] - (hyper): MaskedMLP( - (0): MaskedLinear(in_features=7, out_features=64, bias=True) - (1): ReLU() - (2): MaskedLinear(in_features=64, out_features=64, bias=True) - (3): ReLU() - (4): MaskedLinear(in_features=64, out_features=6, bias=True) - ) - ) - (2): MaskedAutoregressiveTransform( - (base): MonotonicAffineTransform() - (order): [0, 1, 2] - (hyper): MaskedMLP( - (0): MaskedLinear(in_features=7, out_features=64, bias=True) - (1): ReLU() - (2): MaskedLinear(in_features=64, out_features=64, bias=True) - (3): ReLU() - (4): MaskedLinear(in_features=64, out_features=6, bias=True) - ) - ) - ) - (base): DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])) - ) - >>> c = torch.randn(4) - >>> x = flow(c).sample() - >>> x - tensor([-1.7154, -0.4401, 0.7505]) - >>> flow(c).log_prob(x) - tensor(-4.4630, grad_fn=) - """ - - def __init__( - self, - features: int, - context: int = 0, - transforms: int = 3, - randperm: bool = False, - **kwargs, - ): - orders = [ - torch.arange(features), - torch.flipud(torch.arange(features)), - ] - - transforms = [ - MaskedAutoregressiveTransform( - features=features, - context=context, - order=torch.randperm(features) if randperm else orders[i % 2], - **kwargs, - ) - for i in range(transforms) - ] - - base = Unconditional( - DiagNormal, - torch.zeros(features), - torch.ones(features), - buffer=True, - ) - - super().__init__(transforms, base) - - -class NSF(MAF): - r"""Creates a neural spline flow (NSF) with monotonic rational-quadratic spline - transformations. - - Note: - By default, transformations are fully autoregressive. Coupling transformations - can be obtained by setting :py:`passes=2`. - - References: - | Neural Spline Flows (Durkan et al., 2019) - | https://arxiv.org/abs/1906.04032 - - Arguments: - features: The number of features. - context: The number of context features. - bins: The number of bins :math:`K`. - kwargs: Keyword arguments passed to :class:`MAF`. - """ - - def __init__( - self, - features: int, - context: int = 0, - bins: int = 8, - **kwargs, - ): - super().__init__( - features=features, - context=context, - univariate=MonotonicRQSTransform, - shapes=[(bins,), (bins,), (bins - 1,)], - **kwargs, - ) - - -class NCSF(NSF): - r"""Creates a neural circular spline flow (NCSF). - - Note: - Features are assumed to lie in the half-open interval :math:`[-\pi, \pi[`. - - References: - | Normalizing Flows on Tori and Spheres (Rezende et al., 2020) - | https://arxiv.org/abs/2002.02428 - - Arguments: - features: The number of features. - context: The number of context features. - kwargs: Keyword arguments passed to :class:`NSF`. 
- """ - - def __init__( - self, - features: int, - context: int = 0, - **kwargs, - ): - super().__init__(features, context, **kwargs) - - for t in self.transforms: - t.univariate = self.circular_spline - - self.base = Unconditional( - BoxUniform, - torch.full((features,), -pi - 1e-5), - torch.full((features,), pi + 1e-5), - buffer=True, - ) - - @staticmethod - def circular_spline(*args) -> Transform: - return ComposedTransform( - CircularShiftTransform(bound=pi), - MonotonicRQSTransform(*args, bound=pi), - ) - - -class SOSPF(MAF): - r"""Creates a sum-of-squares polynomial flow (SOSPF). - - References: - | Sum-of-Squares Polynomial Flow (Jaini et al., 2019) - | https://arxiv.org/abs/1905.02325 - - Arguments: - features: The number of features. - context: The number of context features. - degree: The degree :math:`L` of polynomials. - polynomials: The number of polynomials :math:`K`. - kwargs: Keyword arguments passed to :class:`MAF`. - """ - - def __init__( - self, - features: int, - context: int = 0, - degree: int = 3, - polynomials: int = 2, - **kwargs, - ): - super().__init__( - features=features, - context=context, - univariate=SOSPolynomialTransform, - shapes=[(polynomials, degree + 1), ()], - **kwargs, - ) - - for i in reversed(range(len(self.transforms))): - self.transforms.insert(i, Unconditional(SoftclipTransform, bound=6.0)) - - -class NeuralAutoregressiveTransform(MaskedAutoregressiveTransform): - r"""Creates a neural autoregressive transformation. - - The monotonic neural network is parametrized by its internal positive weights, - which are independent of the features and context. To modulate its behavior, it - receives as input a signal that is autoregressively dependent on the features - and context. - - References: - | Neural Autoregressive Flows (Huang et al., 2018) - | https://arxiv.org/abs/1804.00779 - - Arguments: - features: The number of features. - context: The number of context features. - signal: The number of signal features of the monotonic network. - network: Keyword arguments passed to :class:`zuko.nn.MonotonicMLP`. - kwargs: Keyword arguments passed to :class:`MaskedAutoregressiveTransform`. 
- - Example: - >>> t = NeuralAutoregressiveTransform(3, 4) - >>> t - NeuralAutoregressiveTransform( - (base): MonotonicTransform() - (order): [0, 1, 2] - (hyper): MaskedMLP( - (0): MaskedLinear(in_features=7, out_features=64, bias=True) - (1): ReLU() - (2): MaskedLinear(in_features=64, out_features=64, bias=True) - (3): ReLU() - (4): MaskedLinear(in_features=64, out_features=24, bias=True) - ) - (network): MonotonicMLP( - (0): MonotonicLinear(in_features=9, out_features=64, bias=True) - (1): TwoWayELU(alpha=1.0) - (2): MonotonicLinear(in_features=64, out_features=64, bias=True) - (3): TwoWayELU(alpha=1.0) - (4): MonotonicLinear(in_features=64, out_features=1, bias=True) - ) - ) - >>> x = torch.randn(3) - >>> x - tensor([-2.3267, 1.4581, -1.6776]) - >>> c = torch.randn(4) - >>> y = t(c)(x) - >>> t(c).inv(y) - tensor([-2.3267, 1.4581, -1.6776]) - """ - - def __init__( - self, - features: int, - context: int = 0, - signal: int = 8, - network: Dict[str, Any] = {}, - **kwargs, - ): - super().__init__( - features=features, - context=context, - univariate=self.univariate, - shapes=[(signal,)], - **kwargs, - ) - - self.network = MonotonicMLP(1 + signal, 1, **network) - - def f(self, signal: Tensor, x: Tensor) -> Tensor: - return self.network( - torch.cat(broadcast(x[..., None], signal, ignore=1), dim=-1) - ).squeeze(dim=-1) - - def univariate(self, signal: Tensor) -> Transform: - return MonotonicTransform( - f=partial(self.f, signal), - phi=(signal, *self.network.parameters()), - ) - - -class NAF(FlowModule): - r"""Creates a neural autoregressive flow (NAF). - - References: - | Neural Autoregressive Flows (Huang et al., 2018) - | https://arxiv.org/abs/1804.00779 - - Arguments: - features: The number of features. - context: The number of context features. - transforms: The number of autoregressive transformations. - randperm: Whether features are randomly permuted between transformations or not. - If :py:`False`, features are in ascending (descending) order for even - (odd) transformations. - unconstrained: Whether to use unconstrained or regular monotonic networks. - kwargs: Keyword arguments passed to :class:`NeuralAutoregressiveTransform`. - """ - - def __init__( - self, - features: int, - context: int = 0, - transforms: int = 3, - randperm: bool = False, - **kwargs, - ): - orders = [ - torch.arange(features), - torch.flipud(torch.arange(features)), - ] - - transforms = [ - NeuralAutoregressiveTransform( - features=features, - context=context, - order=torch.randperm(features) if randperm else orders[i % 2], - **kwargs, - ) - for i in range(transforms) - ] - - for i in reversed(range(len(transforms))): - transforms.insert(i, Unconditional(SoftclipTransform, bound=6.0)) - - base = Unconditional( - DiagNormal, - torch.zeros(features), - torch.ones(features), - buffer=True, - ) - - super().__init__(transforms, base) - - -class UnconstrainedNeuralAutoregressiveTransform(MaskedAutoregressiveTransform): - r"""Creates an unconstrained neural autoregressive transformation. - - The integrand neural network is parametrized by its internal weights, which are - independent of the features and context. To modulate its behavior, it receives as - input a signal that is autoregressively dependent on the features and context. The - integration constant has the same dependencies as the signal. - - References: - | Unconstrained Monotonic Neural Networks (Wehenkel et al., 2019) - | https://arxiv.org/abs/1908.05164 - - Arguments: - features: The number of features. - context: The number of context features. 
- signal: The number of signal features of the integrand network. - network: Keyword arguments passed to :class:`zuko.nn.MLP`. - kwargs: Keyword arguments passed to :class:`MaskedAutoregressiveTransform`. - - Example: - >>> t = UnconstrainedNeuralAutoregressiveTransform(3, 4) - >>> t - UnconstrainedNeuralAutoregressiveTransform( - (base): UnconstrainedMonotonicTransform() - (order): [0, 1, 2] - (hyper): MaskedMLP( - (0): MaskedLinear(in_features=7, out_features=64, bias=True) - (1): ReLU() - (2): MaskedLinear(in_features=64, out_features=64, bias=True) - (3): ReLU() - (4): MaskedLinear(in_features=64, out_features=27, bias=True) - ) - (integrand): MLP( - (0): Linear(in_features=9, out_features=64, bias=True) - (1): ELU(alpha=1.0) - (2): Linear(in_features=64, out_features=64, bias=True) - (3): ELU(alpha=1.0) - (4): Linear(in_features=64, out_features=1, bias=True) - (5): Softplus(beta=1, threshold=20) - ) - ) - >>> x = torch.randn(3) - >>> x - tensor([-0.0103, -1.0871, -0.0667]) - >>> c = torch.randn(4) - >>> y = t(c)(x) - >>> t(c).inv(y) - tensor([-0.0103, -1.0871, -0.0667]) - """ - - def __init__( - self, - features: int, - context: int = 0, - signal: int = 8, - network: Dict[str, Any] = {}, - **kwargs, - ): - super().__init__( - features=features, - context=context, - univariate=self.univariate, - shapes=[(signal,), ()], - **kwargs, - ) - - network.setdefault('activation', nn.ELU) - - self.integrand = MLP(1 + signal, 1, **network) - self.integrand.add_module(str(len(self.integrand)), nn.Softplus()) - - def g(self, signal: Tensor, x: Tensor) -> Tensor: - return self.integrand( - torch.cat(broadcast(x[..., None], signal, ignore=1), dim=-1) - ).squeeze(dim=-1) - - def univariate(self, signal: Tensor, constant: Tensor) -> Transform: - return UnconstrainedMonotonicTransform( - g=partial(self.g, signal), - C=constant, - phi=(signal, *self.integrand.parameters()), - ) - - -class UNAF(FlowModule): - r"""Creates an unconstrained neural autoregressive flow (UNAF). - - References: - | Unconstrained Monotonic Neural Networks (Wehenkel et al., 2019) - | https://arxiv.org/abs/1908.05164 - - Arguments: - features: The number of features. - context: The number of context features. - transforms: The number of autoregressive transformations. - randperm: Whether features are randomly permuted between transformations or not. - If :py:`False`, features are in ascending (descending) order for even - (odd) transformations. - kwargs: Keyword arguments passed to :class:`UnconstrainedNeuralAutoregressiveTransform`. - """ - - def __init__( - self, - features: int, - context: int = 0, - transforms: int = 3, - randperm: bool = False, - **kwargs, - ): - orders = [ - torch.arange(features), - torch.flipud(torch.arange(features)), - ] - - transforms = [ - UnconstrainedNeuralAutoregressiveTransform( - features=features, - context=context, - order=torch.randperm(features) if randperm else orders[i % 2], - **kwargs, - ) - for i in range(transforms) - ] - - for i in reversed(range(len(transforms))): - transforms.insert(i, Unconditional(SoftclipTransform, bound=6.0)) - - base = Unconditional( - DiagNormal, - torch.zeros(features), - torch.ones(features), - buffer=True, - ) - - super().__init__(transforms, base) - - -class GeneralCouplingTransform(TransformModule): - r"""Creates a general coupling transformation. - - References: - | NICE: Non-linear Independent Components Estimation (Dinh et al., 2014) - | https://arxiv.org/abs/1410.8516 - - Arguments: - features: The number of features. 
- context: The number of context features. - mask: The coupling mask. If :py:`None`, use a checkered mask. - univariate: The univariate transformation constructor. - shapes: The shapes of the univariate transformation parameters. - kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. - - Example: - >>> t = GeneralCouplingTransform(3, 4) - >>> t - GeneralCouplingTransform( - (base): MonotonicAffineTransform() - (mask): [0, 1, 0] - (hyper): MLP( - (0): Linear(in_features=5, out_features=64, bias=True) - (1): ReLU() - (2): Linear(in_features=64, out_features=64, bias=True) - (3): ReLU() - (4): Linear(in_features=64, out_features=4, bias=True) - ) - ) - >>> x = torch.randn(3) - >>> x - tensor([-0.8743, 0.6232, 1.2439]) - >>> c = torch.randn(4) - >>> y = t(c)(x) - >>> t(c).inv(y) - tensor([-0.8743, 0.6232, 1.2439]) - """ - - def __init__( - self, - features: int, - context: int = 0, - mask: BoolTensor = None, - univariate: Callable[..., Transform] = MonotonicAffineTransform, - shapes: Sequence[Size] = ((), ()), - **kwargs, - ): - super().__init__() - - # Univariate transformation - self.univariate = univariate - self.shapes = shapes - self.total = sum(prod(s) for s in shapes) - - # Mask - self.register_buffer('mask', None) - - if mask is None: - self.mask = torch.arange(features) % 2 == 1 - else: - self.mask = mask - - features_a = self.mask.sum().item() - features_b = features - features_a - - # Hyper network - self.hyper = MLP(features_a + context, features_b * self.total, **kwargs) - - def extra_repr(self) -> str: - base = self.univariate(*map(torch.randn, self.shapes)) - mask = self.mask.int().tolist() - - if len(mask) > 10: - mask = mask[:5] + [...] + mask[-5:] - mask = str(mask).replace('Ellipsis', '...') - - return '\n'.join([ - f'(base): {base}', - f'(mask): {mask}', - ]) - - def meta(self, c: Tensor, x: Tensor) -> Transform: - if c is not None: - x = torch.cat(broadcast(x, c, ignore=1), dim=-1) - - phi = self.hyper(x) - phi = phi.unflatten(-1, (-1, self.total)) - phi = unpack(phi, self.shapes) - - return DependentTransform(self.univariate(*phi), 1) - - def forward(self, c: Tensor = None) -> Transform: - return CouplingTransform(partial(self.meta, c), self.mask) - - -class GCF(FlowModule): - r"""Creates a general coupling flow (GCF). - - Arguments: - features: The number of features. - context: The number of context features. - transforms: The number of coupling transformations. - randmask: Whether random coupling masks are used or not. If :py:`False`, - use alternating checkered masks. - kwargs: Keyword arguments passed to :class:`GeneralCouplingTransform`. - """ - - def __init__( - self, - features: int, - context: int = 0, - transforms: int = 3, - randmask: bool = False, - **kwargs, - ): - temp = [] - - for i in range(transforms): - if randmask: - mask = torch.randperm(features) % 2 == i % 2 - else: - mask = torch.arange(features) % 2 == i % 2 - - temp.append( - GeneralCouplingTransform( - features=features, - context=context, - mask=mask, - **kwargs, - ) - ) - - base = Unconditional( - DiagNormal, - torch.zeros(features), - torch.ones(features), - buffer=True, - ) - - super().__init__(temp, base) - - -class FFJTransform(TransformModule): - r"""Creates a free-form Jacobian (FFJ) transformation. - - References: - | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) - | https://arxiv.org/abs/1810.01367 - - Arguments: - features: The number of features. - context: The number of context features. 
- freqs: The number of time embedding frequencies. - exact: Whether the exact log-determinant of the Jacobian or an unbiased - stochastic estimate thereof is calculated. - kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. - - Example: - >>> t = FFJTransform(3, 4) - >>> t - FFJTransform( - (ode): MLP( - (0): Linear(in_features=13, out_features=64, bias=True) - (1): ELU(alpha=1.0) - (2): Linear(in_features=64, out_features=64, bias=True) - (3): ELU(alpha=1.0) - (4): Linear(in_features=64, out_features=3, bias=True) - ) - ) - >>> x = torch.randn(3) - >>> x - tensor([ 0.1777, 1.0139, -1.0370]) - >>> c = torch.randn(4) - >>> y = t(c)(x) - >>> t(c).inv(y) - tensor([ 0.1777, 1.0139, -1.0370]) - """ - - def __init__( - self, - features: int, - context: int = 0, - freqs: int = 3, - exact: bool = True, - **kwargs, - ): - super().__init__() - - kwargs.setdefault('activation', nn.ELU) - - self.ode = MLP(features + context + 2 * freqs, features, **kwargs) - - self.register_buffer('times', torch.tensor((0.0, 1.0))) - self.register_buffer('freqs', torch.arange(1, freqs + 1) * pi) - - self.exact = exact - - def f(self, t: Tensor, x: Tensor, c: Tensor = None) -> Tensor: - t = self.freqs * t[..., None] - t = torch.cat((t.cos(), t.sin()), dim=-1) - - if c is None: - x = torch.cat(broadcast(t, x, ignore=1), dim=-1) - else: - x = torch.cat(broadcast(t, x, c, ignore=1), dim=-1) - - return self.ode(x) - - def forward(self, c: Tensor = None) -> Transform: - return FreeFormJacobianTransform( - f=partial(self.f, c=c), - t0=self.times[0], - t1=self.times[1], - phi=self.parameters() if c is None else (c, *self.parameters()), - exact=self.exact, - ) - - -class CNF(FlowModule): - r"""Creates a continuous normalizing flow (CNF) with free-form Jacobian - transformations. - - References: - | Neural Ordinary Differential Equations (Chen el al., 2018) - | https://arxiv.org/abs/1806.07366 - - | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) - | https://arxiv.org/abs/1810.01367 - - Arguments: - features: The number of features. - context: The number of context features. - kwargs: Keyword arguments passed to :class:`FFJTransform`. 
- """ - - def __init__( - self, - features: int, - context: int = 0, - **kwargs, - ): - transforms = [ - FFJTransform( - features=features, - context=context, - **kwargs, - ) - ] - - base = Unconditional( - DiagNormal, - torch.zeros(features), - torch.ones(features), - buffer=True, - ) - - super().__init__(transforms, base) diff --git a/zuko/flows/__init__.py b/zuko/flows/__init__.py new file mode 100644 index 0000000..6935934 --- /dev/null +++ b/zuko/flows/__init__.py @@ -0,0 +1,10 @@ +r"""Parameterized flows and transformations.""" + +from .autoregressive import * +from .continuous import * +from .coupling import * +from .core import * +from .gaussianization import * +from .mixture import * +from .neural import * +from .special import * diff --git a/zuko/flows/autoregressive.py b/zuko/flows/autoregressive.py new file mode 100644 index 0000000..65f7472 --- /dev/null +++ b/zuko/flows/autoregressive.py @@ -0,0 +1,225 @@ +r"""Autoregressive flows and transformations.""" + +__all__ = [ + 'MaskedAutoregressiveTransform', + 'MAF', +] + +import torch +import torch.nn as nn + +from functools import partial +from math import ceil, prod +from torch import Tensor, LongTensor, Size +from torch.distributions import Transform +from typing import * + +from .core import * +from ..distributions import * +from ..transforms import * +from ..nn import MaskedMLP +from ..utils import broadcast, unpack + + +class MaskedAutoregressiveTransform(TransformModule): + r"""Creates a masked autoregressive transformation. + + References: + | Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017) + | https://arxiv.org/abs/1705.07057 + + Arguments: + features: The number of features. + context: The number of context features. + passes: The number of sequential passes for the inverse transformation. If + :py:`None`, use the number of features instead, making the transformation + fully autoregressive. Coupling corresponds to :py:`passes=2`. + order: The feature ordering. If :py:`None`, use :py:`range(features)` instead. + univariate: The univariate transformation constructor. + shapes: The shapes of the univariate transformation parameters. + kwargs: Keyword arguments passed to :class:`zuko.nn.MaskedMLP`. 
+ + Example: + >>> t = MaskedAutoregressiveTransform(3, 4) + >>> t + MaskedAutoregressiveTransform( + (base): MonotonicAffineTransform() + (order): [0, 1, 2] + (hyper): MaskedMLP( + (0): MaskedLinear(in_features=7, out_features=64, bias=True) + (1): ReLU() + (2): MaskedLinear(in_features=64, out_features=64, bias=True) + (3): ReLU() + (4): MaskedLinear(in_features=64, out_features=6, bias=True) + ) + ) + >>> x = torch.randn(3) + >>> x + tensor([-0.9485, 1.5290, 0.2018]) + >>> c = torch.randn(4) + >>> y = t(c)(x) + >>> t(c).inv(y) + tensor([-0.9485, 1.5290, 0.2018]) + """ + + def __init__( + self, + features: int, + context: int = 0, + passes: int = None, + order: LongTensor = None, + univariate: Callable[..., Transform] = MonotonicAffineTransform, + shapes: Sequence[Size] = ((), ()), + **kwargs, + ): + super().__init__() + + # Univariate transformation + self.univariate = univariate + self.shapes = shapes + self.total = sum(prod(s) for s in shapes) + + # Adjacency + self.register_buffer('order', None) + + if passes is None: + passes = features + + if order is None: + order = torch.arange(features) + else: + order = torch.as_tensor(order) + + self.passes = min(max(passes, 1), features) + self.order = torch.div(order, ceil(features / self.passes), rounding_mode='floor') + + in_order = torch.cat((self.order, torch.full((context,), -1))) + out_order = torch.repeat_interleave(self.order, self.total) + adjacency = out_order[:, None] > in_order + + # Hyper network + self.hyper = MaskedMLP(adjacency, **kwargs) + + def extra_repr(self) -> str: + base = self.univariate(*map(torch.randn, self.shapes)) + order = self.order.tolist() + + if len(order) > 10: + order = order[:5] + [...] + order[-5:] + order = str(order).replace('Ellipsis', '...') + + return '\n'.join([ + f'(base): {base}', + f'(order): {order}', + ]) + + def meta(self, c: Tensor, x: Tensor) -> Transform: + if c is not None: + x = torch.cat(broadcast(x, c, ignore=1), dim=-1) + + phi = self.hyper(x) + phi = phi.unflatten(-1, (-1, self.total)) + phi = unpack(phi, self.shapes) + + return DependentTransform(self.univariate(*phi), 1) + + def forward(self, c: Tensor = None) -> Transform: + return AutoregressiveTransform(partial(self.meta, c), self.passes) + + +class MAF(FlowModule): + r"""Creates a masked autoregressive flow (MAF). + + References: + | Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017) + | https://arxiv.org/abs/1705.07057 + + Arguments: + features: The number of features. + context: The number of context features. + transforms: The number of autoregressive transformations. + randperm: Whether features are randomly permuted between transformations or not. + If :py:`False`, features are in ascending (descending) order for even + (odd) transformations. + kwargs: Keyword arguments passed to :class:`MaskedAutoregressiveTransform`. 
+ + Example: + >>> flow = MAF(3, 4, transforms=3) + >>> flow + MAF( + (transforms): ModuleList( + (0): MaskedAutoregressiveTransform( + (base): MonotonicAffineTransform() + (order): [0, 1, 2] + (hyper): MaskedMLP( + (0): MaskedLinear(in_features=7, out_features=64, bias=True) + (1): ReLU() + (2): MaskedLinear(in_features=64, out_features=64, bias=True) + (3): ReLU() + (4): MaskedLinear(in_features=64, out_features=6, bias=True) + ) + ) + (1): MaskedAutoregressiveTransform( + (base): MonotonicAffineTransform() + (order): [2, 1, 0] + (hyper): MaskedMLP( + (0): MaskedLinear(in_features=7, out_features=64, bias=True) + (1): ReLU() + (2): MaskedLinear(in_features=64, out_features=64, bias=True) + (3): ReLU() + (4): MaskedLinear(in_features=64, out_features=6, bias=True) + ) + ) + (2): MaskedAutoregressiveTransform( + (base): MonotonicAffineTransform() + (order): [0, 1, 2] + (hyper): MaskedMLP( + (0): MaskedLinear(in_features=7, out_features=64, bias=True) + (1): ReLU() + (2): MaskedLinear(in_features=64, out_features=64, bias=True) + (3): ReLU() + (4): MaskedLinear(in_features=64, out_features=6, bias=True) + ) + ) + ) + (base): DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])) + ) + >>> c = torch.randn(4) + >>> x = flow(c).sample() + >>> x + tensor([-1.7154, -0.4401, 0.7505]) + >>> flow(c).log_prob(x) + tensor(-4.4630, grad_fn=) + """ + + def __init__( + self, + features: int, + context: int = 0, + transforms: int = 3, + randperm: bool = False, + **kwargs, + ): + orders = [ + torch.arange(features), + torch.flipud(torch.arange(features)), + ] + + transforms = [ + MaskedAutoregressiveTransform( + features=features, + context=context, + order=torch.randperm(features) if randperm else orders[i % 2], + **kwargs, + ) + for i in range(transforms) + ] + + base = Unconditional( + DiagNormal, + torch.zeros(features), + torch.ones(features), + buffer=True, + ) + + super().__init__(transforms, base) diff --git a/zuko/flows/continuous.py b/zuko/flows/continuous.py new file mode 100644 index 0000000..01742e3 --- /dev/null +++ b/zuko/flows/continuous.py @@ -0,0 +1,138 @@ +r"""Continuous flows and transformations.""" + +__all__ = [ + 'FFJTransform', + 'CNF', +] + +import torch +import torch.nn as nn + +from functools import partial +from math import pi +from torch import Tensor +from torch.distributions import Transform +from typing import * + +from .core import * +from ..distributions import * +from ..transforms import * +from ..nn import MLP +from ..utils import broadcast + + +class FFJTransform(TransformModule): + r"""Creates a free-form Jacobian (FFJ) transformation. + + References: + | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) + | https://arxiv.org/abs/1810.01367 + + Arguments: + features: The number of features. + context: The number of context features. + freqs: The number of time embedding frequencies. + exact: Whether the exact log-determinant of the Jacobian or an unbiased + stochastic estimate thereof is calculated. + kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. 
+ + Example: + >>> t = FFJTransform(3, 4) + >>> t + FFJTransform( + (ode): MLP( + (0): Linear(in_features=13, out_features=64, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=64, out_features=64, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=64, out_features=3, bias=True) + ) + ) + >>> x = torch.randn(3) + >>> x + tensor([ 0.1777, 1.0139, -1.0370]) + >>> c = torch.randn(4) + >>> y = t(c)(x) + >>> t(c).inv(y) + tensor([ 0.1777, 1.0139, -1.0370]) + """ + + def __init__( + self, + features: int, + context: int = 0, + freqs: int = 3, + exact: bool = True, + **kwargs, + ): + super().__init__() + + kwargs.setdefault('activation', nn.ELU) + + self.ode = MLP(features + context + 2 * freqs, features, **kwargs) + + self.register_buffer('times', torch.tensor((0.0, 1.0))) + self.register_buffer('freqs', torch.arange(1, freqs + 1) * pi) + + self.exact = exact + + def f(self, t: Tensor, x: Tensor, c: Tensor = None) -> Tensor: + t = self.freqs * t[..., None] + t = torch.cat((t.cos(), t.sin()), dim=-1) + + if c is None: + x = torch.cat(broadcast(t, x, ignore=1), dim=-1) + else: + x = torch.cat(broadcast(t, x, c, ignore=1), dim=-1) + + return self.ode(x) + + def forward(self, c: Tensor = None) -> Transform: + return FreeFormJacobianTransform( + f=partial(self.f, c=c), + t0=self.times[0], + t1=self.times[1], + phi=self.parameters() if c is None else (c, *self.parameters()), + exact=self.exact, + ) + + +class CNF(FlowModule): + r"""Creates a continuous normalizing flow (CNF) with free-form Jacobian + transformations. + + References: + | Neural Ordinary Differential Equations (Chen el al., 2018) + | https://arxiv.org/abs/1806.07366 + + | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) + | https://arxiv.org/abs/1810.01367 + + Arguments: + features: The number of features. + context: The number of context features. + kwargs: Keyword arguments passed to :class:`FFJTransform`. + """ + + def __init__( + self, + features: int, + context: int = 0, + **kwargs, + ): + transforms = [ + FFJTransform( + features=features, + context=context, + **kwargs, + ) + ] + + base = Unconditional( + DiagNormal, + torch.zeros(features), + torch.ones(features), + buffer=True, + ) + + super().__init__(transforms, base) diff --git a/zuko/flows/core.py b/zuko/flows/core.py new file mode 100644 index 0000000..637c878 --- /dev/null +++ b/zuko/flows/core.py @@ -0,0 +1,155 @@ +r"""Core flow building blocks.""" + +__all__ = [ + 'DistributionModule', + 'TransformModule', + 'FlowModule', + 'Unconditional', +] + +import abc +import torch +import torch.nn as nn + +from torch import Tensor +from torch.distributions import Distribution, Transform +from typing import * + +from ..distributions import * +from ..transforms import * + + +class DistributionModule(nn.Module, abc.ABC): + r"""Abstract distribution module. + + A distribution module can be seen as a recipe to build a distribution + :math:`p(X | c)`, given a context :math:`c`. + """ + + @abc.abstractmethod + def forward(self, c: Tensor = None) -> Distribution: + r""" + Arguments: + c: A context :math:`c`. + + Returns: + A distribution :math:`p(X | c)`. + """ + + pass + + +class TransformModule(nn.Module, abc.ABC): + r"""Abstract transformation module. + + A transformation module can be seen as a recipe to build a transformation + :math:`y = f(x; c)`, given a context :math:`c`. + """ + + @abc.abstractmethod + def forward(self, c: Tensor = None) -> Transform: + r""" + Arguments: + c: A context :math:`c`. 
+ + Returns: + A transformation :math:`y = f(x | c)`. + """ + + pass + + +class FlowModule(DistributionModule): + r"""Creates a normalizing flow module. + + Arguments: + transforms: A list of transformation modules. + base: A distribution module. + """ + + def __init__( + self, + transforms: Sequence[TransformModule], + base: DistributionModule, + ): + super().__init__() + + self.transforms = nn.ModuleList(transforms) + self.base = base + + def forward(self, c: Tensor = None) -> NormalizingFlow: + r""" + Arguments: + c: A context :math:`c`. + + Returns: + A normalizing flow :math:`p(X | c)`. + """ + + transform = ComposedTransform(*(t(c) for t in self.transforms)) + + if c is None: + base = self.base(c) + else: + base = self.base(c).expand(c.shape[:-1]) + + return NormalizingFlow(transform, base) + + +class Unconditional(nn.Module): + r"""Creates a module that delays the evaluation of a recipe. + + Typically, the recipe returns an unconditional distribution or transformation to be + part of a :class:`FlowModule`. The positional arguments of the recipe are registered + as buffers or parameters. + + Arguments: + recipe: An arbitrary function. + args: The positional tensor arguments passed to `recipe`. + buffer: Whether tensors are registered as buffers or parameters. + kwargs: The keyword arguments passed to `recipe`. + + Examples: + >>> mu, sigma = torch.zeros(3), torch.ones(3) + >>> d = Unconditional(DiagNormal, mu, sigma, buffer=True) + >>> d() + DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])) + >>> d().sample() + tensor([-0.6687, -0.9690, 1.7461]) + + >>> t = Unconditional(ExpTransform) + >>> t() + ExpTransform() + >>> x = torch.randn(3) + >>> t()(x) + tensor([0.5523, 0.7997, 0.9189]) + """ + + def __init__( + self, + recipe: Callable[..., Any], + *args: Tensor, + buffer: bool = False, + **kwargs, + ): + super().__init__() + + self.recipe = recipe + + for i, arg in enumerate(args): + if buffer: + self.register_buffer(f'_{i}', arg) + else: + self.register_parameter(f'_{i}', nn.Parameter(arg)) + + self.kwargs = kwargs + + def __repr__(self) -> str: + return repr(self.forward()) + + def forward(self, c: Tensor = None) -> Any: + return self.recipe( + *self._parameters.values(), + *self._buffers.values(), + **self.kwargs, + ) diff --git a/zuko/flows/coupling.py b/zuko/flows/coupling.py new file mode 100644 index 0000000..54a37be --- /dev/null +++ b/zuko/flows/coupling.py @@ -0,0 +1,169 @@ +r"""Coupling flows and transformations.""" + +__all__ = [ + 'GeneralCouplingTransform', + 'NICE', +] + +import torch + +from functools import partial +from math import prod +from torch import Tensor, BoolTensor, Size +from torch.distributions import Transform +from typing import * + +from .core import * +from ..distributions import * +from ..transforms import * +from ..nn import MLP +from ..utils import broadcast, unpack + + +class GeneralCouplingTransform(TransformModule): + r"""Creates a general coupling transformation. + + References: + | NICE: Non-linear Independent Components Estimation (Dinh et al., 2014) + | https://arxiv.org/abs/1410.8516 + + Arguments: + features: The number of features. + context: The number of context features. + mask: The coupling mask. If :py:`None`, use a checkered mask. + univariate: The univariate transformation constructor. + shapes: The shapes of the univariate transformation parameters. + kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. 
+ + Example: + >>> t = GeneralCouplingTransform(3, 4) + >>> t + GeneralCouplingTransform( + (base): MonotonicAffineTransform() + (mask): [0, 1, 0] + (hyper): MLP( + (0): Linear(in_features=5, out_features=64, bias=True) + (1): ReLU() + (2): Linear(in_features=64, out_features=64, bias=True) + (3): ReLU() + (4): Linear(in_features=64, out_features=4, bias=True) + ) + ) + >>> x = torch.randn(3) + >>> x + tensor([-0.8743, 0.6232, 1.2439]) + >>> c = torch.randn(4) + >>> y = t(c)(x) + >>> t(c).inv(y) + tensor([-0.8743, 0.6232, 1.2439]) + """ + + def __init__( + self, + features: int, + context: int = 0, + mask: BoolTensor = None, + univariate: Callable[..., Transform] = MonotonicAffineTransform, + shapes: Sequence[Size] = ((), ()), + **kwargs, + ): + super().__init__() + + # Univariate transformation + self.univariate = univariate + self.shapes = shapes + self.total = sum(prod(s) for s in shapes) + + # Mask + self.register_buffer('mask', None) + + if mask is None: + self.mask = torch.arange(features) % 2 == 1 + else: + self.mask = mask + + features_a = self.mask.sum().item() + features_b = features - features_a + + # Hyper network + self.hyper = MLP(features_a + context, features_b * self.total, **kwargs) + + def extra_repr(self) -> str: + base = self.univariate(*map(torch.randn, self.shapes)) + mask = self.mask.int().tolist() + + if len(mask) > 10: + mask = mask[:5] + [...] + mask[-5:] + mask = str(mask).replace('Ellipsis', '...') + + return '\n'.join([ + f'(base): {base}', + f'(mask): {mask}', + ]) + + def meta(self, c: Tensor, x: Tensor) -> Transform: + if c is not None: + x = torch.cat(broadcast(x, c, ignore=1), dim=-1) + + phi = self.hyper(x) + phi = phi.unflatten(-1, (-1, self.total)) + phi = unpack(phi, self.shapes) + + return DependentTransform(self.univariate(*phi), 1) + + def forward(self, c: Tensor = None) -> Transform: + return CouplingTransform(partial(self.meta, c), self.mask) + + +class NICE(FlowModule): + r"""Creates a NICE flow. + + Affine transformations are used by default, instead of the additive transformations + used by Dinh et al. (2014) originally. + + References: + | NICE: Non-linear Independent Components Estimation (Dinh et al., 2014) + | https://arxiv.org/abs/1410.8516 + + Arguments: + features: The number of features. + context: The number of context features. + transforms: The number of coupling transformations. + randmask: Whether random coupling masks are used or not. If :py:`False`, + use alternating checkered masks. + kwargs: Keyword arguments passed to :class:`GeneralCouplingTransform`. 
+ """ + + def __init__( + self, + features: int, + context: int = 0, + transforms: int = 3, + randmask: bool = False, + **kwargs, + ): + temp = [] + + for i in range(transforms): + if randmask: + mask = torch.randperm(features) % 2 == i % 2 + else: + mask = torch.arange(features) % 2 == i % 2 + + temp.append( + GeneralCouplingTransform( + features=features, + context=context, + mask=mask, + **kwargs, + ) + ) + + base = Unconditional( + DiagNormal, + torch.zeros(features), + torch.ones(features), + buffer=True, + ) + + super().__init__(temp, base) diff --git a/zuko/flows/gaussianization.py b/zuko/flows/gaussianization.py new file mode 100644 index 0000000..95ce36b --- /dev/null +++ b/zuko/flows/gaussianization.py @@ -0,0 +1,142 @@ +r"""Gaussianization flows.""" + +__all__ = [ + 'ElementWiseTransform', + 'GF', +] + +import torch +import torch.nn as nn + +from math import prod +from torch import Tensor, Size +from torch.distributions import Transform +from typing import * + +from .core import * +from ..distributions import * +from ..transforms import * +from ..nn import MLP +from ..utils import unpack + + +class ElementWiseTransform(TransformModule): + r"""Creates an element-wise transformation. + + Arguments: + features: The number of features. + context: The number of context features. + univariate: The univariate transformation constructor. + shapes: The shapes of the univariate transformation parameters. + kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. + + Example: + >>> t = ElementWiseTransform(3, 4) + >>> t + ElementWiseTransform( + (base): MonotonicAffineTransform() + (hyper): MLP( + (0): Linear(in_features=4, out_features=64, bias=True) + (1): ReLU() + (2): Linear(in_features=64, out_features=64, bias=True) + (3): ReLU() + (4): Linear(in_features=64, out_features=6, bias=True) + ) + ) + >>> x = torch.randn(3) + >>> x + tensor([2.1983, -1.3182, 0.0329]) + >>> c = torch.randn(4) + >>> y = t(c)(x) + >>> t(c).inv(y) + tensor([2.1983, -1.3182, 0.0329]) + """ + + def __init__( + self, + features: int, + context: int = 0, + univariate: Callable[..., Transform] = MonotonicAffineTransform, + shapes: Sequence[Size] = ((), ()), + **kwargs, + ): + super().__init__() + + self.univariate = univariate + self.shapes = shapes + self.total = sum(prod(s) for s in shapes) + + if context > 0: + self.hyper = MLP(context, features * self.total, **kwargs) + else: + self.phi = nn.ParameterList(torch.randn(features, *s) for s in shapes) + + def extra_repr(self) -> str: + base = self.univariate(*map(torch.randn, self.shapes)) + + return '\n'.join([ + f'(base): {base}', + ]) + + def forward(self, c: Tensor = None) -> Transform: + if c is None: + phi = self.phi + else: + phi = self.hyper(c) + phi = phi.unflatten(-1, (-1, self.total)) + phi = unpack(phi, self.shapes) + + return DependentTransform(self.univariate(*phi), 1) + + +class GF(FlowModule): + r"""Creates a gaussianization flow (GF). + + References: + | Gaussianization Flows (Meng et al., 2020) + | https://arxiv.org/abs/2003.01941 + + Arguments: + features: The number of features. + context: The number of context features. + transforms: The number of coupling transformations. + components: The number of mixture components in each transformation. + kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. 
+ """ + + def __init__( + self, + features: int, + context: int = 0, + transforms: int = 3, + components: int = 8, + **kwargs, + ): + transforms = [ + ElementWiseTransform( + features=features, + context=context, + univariate=GaussianizationTransform, + shapes=[(components,), (components,)], + **kwargs, + ) + for _ in range(transforms) + ] + + for i in reversed(range(len(transforms))): + transforms.insert( + i + 1, + Unconditional( + RotationTransform, + torch.randn(features, features), + ), + ) + + base = Unconditional( + DiagNormal, + torch.zeros(features), + torch.ones(features), + buffer=True, + ) + + super().__init__(transforms, base) diff --git a/zuko/flows/mixture.py b/zuko/flows/mixture.py new file mode 100644 index 0000000..b9118b4 --- /dev/null +++ b/zuko/flows/mixture.py @@ -0,0 +1,74 @@ +r"""Mixture models.""" + +__all__ = [ + 'GMM', +] + +import torch +import torch.nn as nn + +from math import prod +from torch import Tensor +from torch.distributions import * +from typing import * + +from .core import * +from ..distributions import * +from ..transforms import * +from ..nn import MLP +from ..utils import unpack + + +class GMM(DistributionModule): + r"""Creates a Gaussian mixture model (GMM). + + .. math:: p(X | c) = \sum_{i = 1}^K w_i(c) \, \mathcal{N}(X | \mu_i(c), \Sigma_i(c)) + + Wikipedia: + https://wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model + + Arguments: + features: The number of features. + context: The number of context features. + components: The number of components :math:`K` in the mixture. + kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. + """ + + def __init__( + self, + features: int, + context: int = 0, + components: int = 2, + **kwargs, + ): + super().__init__() + + shapes = [ + (components,), # probabilities + (components, features), # mean + (components, features), # diagonal + (components, features * (features - 1) // 2), # off diagonal + ] + + self.shapes = shapes + self.total = sum(prod(s) for s in shapes) + + if context > 0: + self.hyper = MLP(context, self.total, **kwargs) + else: + self.phi = nn.ParameterList(torch.randn(*s) for s in shapes) + + def forward(self, c: Tensor = None) -> Distribution: + if c is None: + phi = self.phi + else: + phi = self.hyper(c) + phi = unpack(phi, self.shapes) + + logits, loc, diag, tril = phi + + scale = torch.diag_embed(diag.exp() + 1e-5) + mask = torch.tril(torch.ones_like(scale, dtype=bool), diagonal=-1) + scale = torch.masked_scatter(scale, mask, tril) + + return Mixture(MultivariateNormal(loc=loc, scale_tril=scale), logits) diff --git a/zuko/flows/neural.py b/zuko/flows/neural.py new file mode 100644 index 0000000..1b3332f --- /dev/null +++ b/zuko/flows/neural.py @@ -0,0 +1,287 @@ +r"""Neural flows and transformations.""" + +__all__ = [ + 'NeuralAutoregressiveTransform', + 'NAF', + 'UnconstrainedNeuralAutoregressiveTransform', + 'UNAF', +] + +import torch +import torch.nn as nn + +from functools import partial +from torch import Tensor +from torch.distributions import Transform +from typing import * + +from .autoregressive import * +from .core import * +from ..distributions import * +from ..transforms import * +from ..nn import MLP, MonotonicMLP +from ..utils import broadcast + + +class NeuralAutoregressiveTransform(MaskedAutoregressiveTransform): + r"""Creates a neural autoregressive transformation. + + The monotonic neural network is parametrized by its internal positive weights, + which are independent of the features and context. 
To modulate its behavior, it + receives as input a signal that is autoregressively dependent on the features + and context. + + References: + | Neural Autoregressive Flows (Huang et al., 2018) + | https://arxiv.org/abs/1804.00779 + + Arguments: + features: The number of features. + context: The number of context features. + signal: The number of signal features of the monotonic network. + network: Keyword arguments passed to :class:`zuko.nn.MonotonicMLP`. + kwargs: Keyword arguments passed to :class:`MaskedAutoregressiveTransform`. + + Example: + >>> t = NeuralAutoregressiveTransform(3, 4) + >>> t + NeuralAutoregressiveTransform( + (base): MonotonicTransform() + (order): [0, 1, 2] + (hyper): MaskedMLP( + (0): MaskedLinear(in_features=7, out_features=64, bias=True) + (1): ReLU() + (2): MaskedLinear(in_features=64, out_features=64, bias=True) + (3): ReLU() + (4): MaskedLinear(in_features=64, out_features=24, bias=True) + ) + (network): MonotonicMLP( + (0): MonotonicLinear(in_features=9, out_features=64, bias=True) + (1): TwoWayELU(alpha=1.0) + (2): MonotonicLinear(in_features=64, out_features=64, bias=True) + (3): TwoWayELU(alpha=1.0) + (4): MonotonicLinear(in_features=64, out_features=1, bias=True) + ) + ) + >>> x = torch.randn(3) + >>> x + tensor([-2.3267, 1.4581, -1.6776]) + >>> c = torch.randn(4) + >>> y = t(c)(x) + >>> t(c).inv(y) + tensor([-2.3267, 1.4581, -1.6776]) + """ + + def __init__( + self, + features: int, + context: int = 0, + signal: int = 8, + network: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__( + features=features, + context=context, + univariate=self.univariate, + shapes=[(signal,)], + **kwargs, + ) + + self.network = MonotonicMLP(1 + signal, 1, **network) + + def f(self, signal: Tensor, x: Tensor) -> Tensor: + return self.network( + torch.cat(broadcast(x[..., None], signal, ignore=1), dim=-1) + ).squeeze(dim=-1) + + def univariate(self, signal: Tensor) -> Transform: + return MonotonicTransform( + f=partial(self.f, signal), + phi=(signal, *self.network.parameters()), + ) + + +class NAF(FlowModule): + r"""Creates a neural autoregressive flow (NAF). + + References: + | Neural Autoregressive Flows (Huang et al., 2018) + | https://arxiv.org/abs/1804.00779 + + Arguments: + features: The number of features. + context: The number of context features. + transforms: The number of autoregressive transformations. + randperm: Whether features are randomly permuted between transformations or not. + If :py:`False`, features are in ascending (descending) order for even + (odd) transformations. + unconstrained: Whether to use unconstrained or regular monotonic networks. + kwargs: Keyword arguments passed to :class:`NeuralAutoregressiveTransform`. + """ + + def __init__( + self, + features: int, + context: int = 0, + transforms: int = 3, + randperm: bool = False, + **kwargs, + ): + orders = [ + torch.arange(features), + torch.flipud(torch.arange(features)), + ] + + transforms = [ + NeuralAutoregressiveTransform( + features=features, + context=context, + order=torch.randperm(features) if randperm else orders[i % 2], + **kwargs, + ) + for i in range(transforms) + ] + + base = Unconditional( + DiagNormal, + torch.zeros(features), + torch.ones(features), + buffer=True, + ) + + super().__init__(transforms, base) + + +class UnconstrainedNeuralAutoregressiveTransform(MaskedAutoregressiveTransform): + r"""Creates an unconstrained neural autoregressive transformation. 
+ + The integrand neural network is parametrized by its internal weights, which are + independent of the features and context. To modulate its behavior, it receives as + input a signal that is autoregressively dependent on the features and context. The + integration constant has the same dependencies as the signal. + + References: + | Unconstrained Monotonic Neural Networks (Wehenkel et al., 2019) + | https://arxiv.org/abs/1908.05164 + + Arguments: + features: The number of features. + context: The number of context features. + signal: The number of signal features of the integrand network. + network: Keyword arguments passed to :class:`zuko.nn.MLP`. + kwargs: Keyword arguments passed to :class:`MaskedAutoregressiveTransform`. + + Example: + >>> t = UnconstrainedNeuralAutoregressiveTransform(3, 4) + >>> t + UnconstrainedNeuralAutoregressiveTransform( + (base): UnconstrainedMonotonicTransform() + (order): [0, 1, 2] + (hyper): MaskedMLP( + (0): MaskedLinear(in_features=7, out_features=64, bias=True) + (1): ReLU() + (2): MaskedLinear(in_features=64, out_features=64, bias=True) + (3): ReLU() + (4): MaskedLinear(in_features=64, out_features=27, bias=True) + ) + (integrand): MLP( + (0): Linear(in_features=9, out_features=64, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=64, out_features=64, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=64, out_features=1, bias=True) + (5): Softplus(beta=1, threshold=20) + ) + ) + >>> x = torch.randn(3) + >>> x + tensor([-0.0103, -1.0871, -0.0667]) + >>> c = torch.randn(4) + >>> y = t(c)(x) + >>> t(c).inv(y) + tensor([-0.0103, -1.0871, -0.0667]) + """ + + def __init__( + self, + features: int, + context: int = 0, + signal: int = 8, + network: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__( + features=features, + context=context, + univariate=self.univariate, + shapes=[(signal,), ()], + **kwargs, + ) + + network.setdefault('activation', nn.ELU) + + self.integrand = MLP(1 + signal, 1, **network) + self.integrand.append(nn.Softplus()) + + def g(self, signal: Tensor, x: Tensor) -> Tensor: + return self.integrand( + torch.cat(broadcast(x[..., None], signal, ignore=1), dim=-1) + ).squeeze(dim=-1) + + def univariate(self, signal: Tensor, constant: Tensor) -> Transform: + return UnconstrainedMonotonicTransform( + g=partial(self.g, signal), + C=constant, + phi=(signal, *self.integrand.parameters()), + ) + + +class UNAF(FlowModule): + r"""Creates an unconstrained neural autoregressive flow (UNAF). + + References: + | Unconstrained Monotonic Neural Networks (Wehenkel et al., 2019) + | https://arxiv.org/abs/1908.05164 + + Arguments: + features: The number of features. + context: The number of context features. + transforms: The number of autoregressive transformations. + randperm: Whether features are randomly permuted between transformations or not. + If :py:`False`, features are in ascending (descending) order for even + (odd) transformations. + kwargs: Keyword arguments passed to :class:`UnconstrainedNeuralAutoregressiveTransform`. 
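+
+    Example:
+        A minimal usage sketch (shapes only). Note that sampling relies on a
+        numerical inversion of the monotonic transformations, so it is
+        typically slower than evaluating :py:`log_prob`.
+
+        >>> flow = UNAF(3, 4)
+        >>> c = torch.randn(4)
+        >>> x = flow(c).sample()
+        >>> x.shape
+        torch.Size([3])
+        >>> flow(c).log_prob(x).shape
+        torch.Size([])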
+ """ + + def __init__( + self, + features: int, + context: int = 0, + transforms: int = 3, + randperm: bool = False, + **kwargs, + ): + orders = [ + torch.arange(features), + torch.flipud(torch.arange(features)), + ] + + transforms = [ + UnconstrainedNeuralAutoregressiveTransform( + features=features, + context=context, + order=torch.randperm(features) if randperm else orders[i % 2], + **kwargs, + ) + for i in range(transforms) + ] + + base = Unconditional( + DiagNormal, + torch.zeros(features), + torch.ones(features), + buffer=True, + ) + + super().__init__(transforms, base) diff --git a/zuko/flows/special.py b/zuko/flows/special.py new file mode 100644 index 0000000..7eaed4d --- /dev/null +++ b/zuko/flows/special.py @@ -0,0 +1,127 @@ +r"""Special flows.""" + +__all__ = [ + 'NSF', + 'NCSF', + 'SOSPF', +] + +import torch + +from math import pi +from torch import Tensor +from torch.distributions import * +from typing import * + +from .autoregressive import MAF +from .core import * +from ..distributions import * +from ..transforms import * + + +class NSF(MAF): + r"""Creates a neural spline flow (NSF) with monotonic rational-quadratic spline + transformations. + + By default, transformations are fully autoregressive. Coupling transformations can + be obtained by setting :py:`passes=2`. + + References: + | Neural Spline Flows (Durkan et al., 2019) + | https://arxiv.org/abs/1906.04032 + + Arguments: + features: The number of features. + context: The number of context features. + bins: The number of bins :math:`K`. + kwargs: Keyword arguments passed to :class:`zuko.flows.autoregressive.MAF`. + """ + + def __init__( + self, + features: int, + context: int = 0, + bins: int = 8, + **kwargs, + ): + super().__init__( + features=features, + context=context, + univariate=MonotonicRQSTransform, + shapes=[(bins,), (bins,), (bins - 1,)], + **kwargs, + ) + + +class NCSF(NSF): + r"""Creates a neural circular spline flow (NCSF). + + Note: + Features are assumed to lie in the half-open interval :math:`[-\pi, \pi[`. + + References: + | Normalizing Flows on Tori and Spheres (Rezende et al., 2020) + | https://arxiv.org/abs/2002.02428 + + Arguments: + features: The number of features. + context: The number of context features. + kwargs: Keyword arguments passed to :class:`NSF`. + """ + + def __init__( + self, + features: int, + context: int = 0, + **kwargs, + ): + super().__init__(features, context, **kwargs) + + for t in self.transforms: + t.univariate = self.circular_spline + + self.base = Unconditional( + BoxUniform, + torch.full((features,), -pi - 1e-5), + torch.full((features,), pi + 1e-5), + buffer=True, + ) + + @staticmethod + def circular_spline(*args) -> Transform: + return ComposedTransform( + CircularShiftTransform(bound=pi), + MonotonicRQSTransform(*args, bound=pi), + ) + + +class SOSPF(MAF): + r"""Creates a sum-of-squares polynomial flow (SOSPF). + + References: + | Sum-of-Squares Polynomial Flow (Jaini et al., 2019) + | https://arxiv.org/abs/1905.02325 + + Arguments: + features: The number of features. + context: The number of context features. + degree: The degree :math:`L` of polynomials. + polynomials: The number of polynomials :math:`K`. + kwargs: Keyword arguments passed to :class:`zuko.flows.autoregressive.MAF`. 
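+
+    Example:
+        A minimal usage sketch (shapes only, since sampled values depend on
+        the random initialization).
+
+        >>> flow = SOSPF(3, 4)
+        >>> c = torch.randn(4)
+        >>> x = flow(c).sample()
+        >>> x.shape
+        torch.Size([3])
+        >>> flow(c).log_prob(x).shape
+        torch.Size([])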
+ """ + + def __init__( + self, + features: int, + context: int = 0, + degree: int = 3, + polynomials: int = 2, + **kwargs, + ): + super().__init__( + features=features, + context=context, + univariate=SOSPolynomialTransform, + shapes=[(polynomials, degree + 1), ()], + **kwargs, + ) diff --git a/zuko/transforms.py b/zuko/transforms.py index 4886d08..6174727 100644 --- a/zuko/transforms.py +++ b/zuko/transforms.py @@ -11,13 +11,14 @@ 'MonotonicAffineTransform', 'MonotonicRQSTransform', 'MonotonicTransform', + 'GaussianizationTransform', 'UnconstrainedMonotonicTransform', 'SOSPolynomialTransform', - 'FreeFormJacobianTransform', 'AutoregressiveTransform', 'CouplingTransform', - 'LULinearTransform', + 'FreeFormJacobianTransform', 'PermutationTransform', + 'RotationTransform', ] import math @@ -287,15 +288,15 @@ class SoftclipTransform(Transform): bound: The codomain bound :math:`B`. """ - domain = constraints.real - codomain = constraints.real bijective = True sign = +1 - def __init__(self, bound: float = 5.0, **kwargs): + def __init__(self, bound: float = 1.0, **kwargs): super().__init__(**kwargs) self.bound = bound + self.domain = constraints.real + self.codomain = constraints.interval(-bound, bound) def __repr__(self) -> str: return f'{self.__class__.__name__}(bound={self.bound})' @@ -323,14 +324,14 @@ class CircularShiftTransform(Transform): bound: The domain bound :math:`B`. """ - domain = constraints.real - codomain = constraints.real bijective = True - def __init__(self, bound: float = 5.0, **kwargs): + def __init__(self, bound: float = 1.0, **kwargs): super().__init__(**kwargs) self.bound = bound + self.domain = constraints.interval(-bound, bound) + self.codomain = constraints.interval(-bound, bound) def __repr__(self) -> str: return f'{self.__class__.__name__}(bound={self.bound})' @@ -346,11 +347,11 @@ def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: class MonotonicAffineTransform(Transform): - r"""Creates a transformation :math:`f(x) = \alpha x + \beta`. + r"""Creates a transformation :math:`f(x) = \exp(a) x + b`. Arguments: - shift: The shift term :math:`\beta`, with shape :math:`(*,)`. - scale: The unconstrained scale factor :math:`\alpha`, with shape :math:`(*,)`. + shift: The shift term :math:`b`, with shape :math:`(*,)`. + scale: The unconstrained scale factor :math:`a`, with shape :math:`(*,)`. slope: The minimum slope of the transformation. """ @@ -530,7 +531,7 @@ def __init__( self, f: Callable[[Tensor], Tensor], phi: Iterable[Tensor] = (), - bound: float = 5.0, + bound: float = 10.0, eps: float = 1e-6, **kwargs, ): @@ -560,7 +561,7 @@ def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: with torch.enable_grad(): - x = x.requires_grad_() + x = x.view_as(x).requires_grad_() # shallow copy y = self.f(x) jacobian = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0] @@ -568,6 +569,50 @@ def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: return y, jacobian.log() +class GaussianizationTransform(MonotonicTransform): + r"""Creates a gaussianization transformation. + + .. math:: f(x) = \Phi^{-1} + \left( \frac{1}{K} \sum_{i=1}^K \Phi(\exp(a_i) x + b_i) \right) + + where :math:`\Phi` is the cumulative distribution function (CDF) of the standard + normal :math:`\mathcal{N}(0, 1)`. + + References: + | Gaussianization (Chen et al., 2000) + | https://papers.nips.cc/paper/1856-gaussianization + + Arguments: + shift: The shift terms :math:`b`, with shape :math:`(*, K)`. 
+ scale: The unconstrained scale factors :math:`a`, with shape :math:`(*, K)`. + kwargs: Keyword arguments passed to :class:`MonotonicTransform`. + """ + + domain = constraints.real + codomain = constraints.real + bijective = True + sign = +1 + + def __init__( + self, + shift: Tensor, + scale: Tensor, + **kwargs, + ): + super().__init__(self.f, phi=(shift, scale), **kwargs) + + self.shift = shift + self.scale = torch.exp(scale) + + def f(self, x: Tensor) -> Tensor: + y = self.scale * x[..., None] + self.shift + y = torch.erf(y / math.sqrt(2)) + y = torch.mean(y, dim=-1) + y = torch.erfinv(y) * math.sqrt(2) + + return y + + class UnconstrainedMonotonicTransform(MonotonicTransform): r"""Creates a monotonic transformation :math:`f(x)` by integrating a positive univariate function :math:`g(x)`. @@ -623,7 +668,7 @@ class SOSPolynomialTransform(UnconstrainedMonotonicTransform): The transformation :math:`f(x)` is expressed as the primitive integral of the sum of :math:`K` squared polynomials of degree :math:`L`. - .. math:: f(x) = \int_0^x \sum_{i = 1}^K + .. math:: f(x) = \int_0^x \frac{1}{K} \sum_{i = 1}^K \left( 1 + \sum_{j = 0}^L a_{i,j} ~ u^j \right)^2 ~ du + C References: @@ -652,106 +697,16 @@ def g(self, x: Tensor) -> Tensor: x = x[..., None] ** self.i p = 1 + self.a @ x[..., None] - return p.squeeze(dim=-1).square().sum(dim=-1) - - -class FreeFormJacobianTransform(Transform): - r"""Creates a free-form Jacobian transformation. - - The transformation is the integration of a system of first-order ordinary - differential equations - - .. math:: x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(t, x(t)) ~ dt . - - References: - | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) - | https://arxiv.org/abs/1810.01367 - - Arguments: - f: A system of first-order ODEs :math:`f_\phi`. - t0: The initial integration time :math:`t_0`. - t1: The final integration time :math:`t_1`. - phi: The parameters :math:`\phi` of :math:`f_\phi`. - exact: Whether the exact log-determinant of the Jacobian or an unbiased - stochastic estimate thereof is calculated. 
- """ - - domain = constraints.real_vector - codomain = constraints.real_vector - bijective = True - - def __init__( - self, - f: Callable[[Tensor, Tensor], Tensor], - t0: Union[float, Tensor] = 0.0, - t1: Union[float, Tensor] = 1.0, - phi: Iterable[Tensor] = (), - exact: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - - self.f = f - self.t0 = t0 - self.t1 = t1 - self.phi = tuple(filter(lambda p: p.requires_grad, phi)) - self.exact = exact - self.trace_scale = 1e-2 # relax jacobian tolerances - - def _call(self, x: Tensor) -> Tensor: - return odeint(self.f, x, self.t0, self.t1, self.phi) - - @property - def inv(self) -> Transform: - return FreeFormJacobianTransform( - f=self.f, - t0=self.t1, - t1=self.t0, - phi=self.phi, - exact=self.exact, - ) - - def _inverse(self, y: Tensor) -> Tensor: - return odeint(self.f, y, self.t1, self.t0, self.phi) - - def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: - _, ladj = self.call_and_ladj(x) - return ladj - - def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: - if self.exact: - I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device) - I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0) - else: - eps = torch.randn_like(x) - - def f_aug(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor: - with torch.enable_grad(): - x = x.requires_grad_() - dx = self.f(t, x) - - if self.exact: - jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0] - trace = torch.einsum('i...i', jacobian) - else: - epsjp = torch.autograd.grad(dx, x, eps, create_graph=True)[0] - trace = (epsjp * eps).sum(dim=-1) - - return dx, trace * self.trace_scale - - ladj = torch.zeros_like(x[..., 0]) - y, ladj = odeint(f_aug, (x, ladj), self.t0, self.t1, self.phi) - - return y, ladj * (1 / self.trace_scale) + return p.squeeze(dim=-1).square().mean(dim=-1) class AutoregressiveTransform(Transform): r"""Transform via an autoregressive scheme. - .. math:: y_i = f(x_i; x_{ Tuple[Tensor, Tensor]: return y, ladj -class LULinearTransform(Transform): - r"""Creates a transformation :math:`f(x) = L U x`. +class FreeFormJacobianTransform(Transform): + r"""Creates a free-form Jacobian transformation. + + The transformation is the integration of a system of first-order ordinary + differential equations + + .. math:: x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(t, x(t)) ~ dt . + + References: + | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) + | https://arxiv.org/abs/1810.01367 Arguments: - LU: A matrix whose lower and upper triangular parts are the non-zero elements - of :math:`L` and :math:`U`, with shape :math:`(*, D, D)`. + f: A system of first-order ODEs :math:`f_\phi`. + t0: The initial integration time :math:`t_0`. + t1: The final integration time :math:`t_1`. + phi: The parameters :math:`\phi` of :math:`f_\phi`. + exact: Whether the exact log-determinant of the Jacobian or an unbiased + stochastic estimate thereof is calculated. 
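+
+    Example:
+        A minimal sketch with a hand-written, parameter-free vector field
+        chosen for illustration. The round trip is only exact up to the
+        tolerance of the ODE solver.
+
+        >>> f = lambda t, x: -x  # toy vector field, not part of the library
+        >>> t = FreeFormJacobianTransform(f)
+        >>> x = torch.randn(3)
+        >>> y = t(x)
+        >>> torch.allclose(t.inv(y), x, atol=1e-4)
+        True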
""" domain = constraints.real_vector codomain = constraints.real_vector bijective = True - def __init__(self, LU: Tensor, **kwargs): + def __init__( + self, + f: Callable[[Tensor, Tensor], Tensor], + t0: Union[float, Tensor] = 0.0, + t1: Union[float, Tensor] = 1.0, + phi: Iterable[Tensor] = (), + exact: bool = True, + **kwargs, + ): super().__init__(**kwargs) - I = torch.eye(LU.shape[-1], dtype=LU.dtype, device=LU.device) - - self.L = torch.tril(LU, diagonal=-1) + I - self.U = torch.triu(LU, diagonal=+1) + I + self.f = f + self.t0 = t0 + self.t1 = t1 + self.phi = tuple(filter(lambda p: p.requires_grad, phi)) + self.exact = exact + self.trace_scale = 1e-2 # relax jacobian tolerances def _call(self, x: Tensor) -> Tensor: - return (self.L @ self.U @ x.unsqueeze(-1)).squeeze(-1) + return odeint(self.f, x, self.t0, self.t1, self.phi) + + @property + def inv(self) -> Transform: + return FreeFormJacobianTransform( + f=self.f, + t0=self.t1, + t1=self.t0, + phi=self.phi, + exact=self.exact, + ) def _inverse(self, y: Tensor) -> Tensor: - return torch.linalg.solve_triangular( - self.U, - torch.linalg.solve_triangular( - self.L, - y.unsqueeze(-1), - upper=False, - unitriangular=True, - ), - upper=True, - unitriangular=True, - ).squeeze(-1) + return odeint(self.f, y, self.t1, self.t0, self.phi) def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: - return x.new_zeros(x.shape[:-1]) + _, ladj = self.call_and_ladj(x) + return ladj + + def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: + if self.exact: + I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device) + I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0) + else: + eps = torch.randn_like(x) + + def f_aug(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor: + with torch.enable_grad(): + x = x.view_as(x).requires_grad_() # shallow copy + dx = self.f(t, x) + + if self.exact: + jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0] + trace = torch.einsum('i...i', jacobian) + else: + epsjp = torch.autograd.grad(dx, x, eps, create_graph=True)[0] + trace = (epsjp * eps).sum(dim=-1) + + return dx, trace * self.trace_scale + + ladj = torch.zeros_like(x[..., 0]) + y, ladj = odeint(f_aug, (x, ladj), self.t0, self.t1, self.phi) + + return y, ladj * (1 / self.trace_scale) class PermutationTransform(Transform): @@ -925,3 +930,33 @@ def _inverse(self, y: Tensor) -> Tensor: def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: return x.new_zeros(x.shape[:-1]) + + +class RotationTransform(Transform): + r"""Creates a rotation transformation :math:`f(x) = R x`. + + .. math:: R = \exp(A - A^T) + + Because :math:`A - A^T` is skew-symmetric, :math:`R` is orthogonal. + + Arguments: + A: A square matrix :math:`A`, with shape :math:`(*, D, D)`. + """ + + domain = constraints.real_vector + codomain = constraints.real_vector + bijective = True + + def __init__(self, A: Tensor, **kwargs): + super().__init__(**kwargs) + + self.R = torch.linalg.matrix_exp(A - A.mT) + + def _call(self, x: Tensor) -> Tensor: + return x @ self.R + + def _inverse(self, y: Tensor) -> Tensor: + return y @ self.R.mT + + def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: + return x.new_zeros(x.shape[:-1])