Add neural score estimation (NSE) #7

Merged
12 commits, merged on Jan 17, 2023
286 changes: 239 additions & 47 deletions lampe/inference.py

Large diffs are not rendered by default.
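Since the lampe/inference.py diff is collapsed, the following is a minimal sketch of how the new estimator appears to be used, inferred from the calls exercised in tests/test_inference.py below (NSE(theta_dim, x_dim), the (theta, x, t) forward signature, estimator.flow(x), and NSELoss). The import path, the optimizer, and the surrounding training step are assumptions, not the definitive API.

import torch
from lampe.inference import NSE, NSELoss  # names taken from the tests below

estimator = NSE(3, 5)  # theta has 3 features, x has 5 (as in the tests)
loss = NSELoss(estimator)
optimizer = torch.optim.Adam(estimator.parameters(), lr=1e-3)

# One (placeholder) training step on a batch of joint samples (theta, x)
theta, x = torch.randn(256, 3), torch.randn(256, 5)
l = loss(theta, x)  # scalar and differentiable, per test_NSELoss
l.backward()
optimizer.step()
optimizer.zero_grad()

# Posterior inference through the flow induced by the learned score
estimator.eval()

with torch.no_grad():
    x_star = torch.randn(5)               # a single observation
    posterior = estimator.flow(x_star)    # conditional distribution over theta
    samples = posterior.sample((1024,))
    log_p = posterior.log_prob(samples)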

31 changes: 15 additions & 16 deletions lampe/nn.py
@@ -1,15 +1,13 @@
r"""Neural networks, layers and modules."""

__all__ = ['ResMLP']
__all__ = ['MLP', 'ResMLP']

import torch
import torch.nn as nn

from zuko.nn import MLP

from textwrap import indent
from torch import Tensor
from typing import *
from zuko.nn import MLP


class Affine(nn.Module):
@@ -20,13 +18,18 @@ class Affine(nn.Module):
Arguments:
shift: The shift term :math:`\beta`.
scale: The scale factor :math:`\alpha`.
trainable: Whether the layer is trainable or not.
"""

    def __init__(self, shift: Tensor, scale: Tensor):
    def __init__(self, shift: Tensor, scale: Tensor, trainable: bool = True):
        super().__init__()

        self.shift = nn.Parameter(torch.as_tensor(shift))
        self.scale = nn.Parameter(torch.as_tensor(scale))
        if trainable:
            self.shift = nn.Parameter(torch.as_tensor(shift))
            self.scale = nn.Parameter(torch.as_tensor(scale))
        else:
            self.register_buffer('shift', torch.as_tensor(shift))
            self.register_buffer('scale', torch.as_tensor(scale))

    def forward(self, x: Tensor) -> Tensor:
        return x * self.scale + self.shift
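As a quick illustration of the new trainable flag: with trainable=False, shift and scale are registered as buffers rather than parameters, so they move with .to(device) and are stored in the state dict but receive no gradients. A minimal sketch, assuming Affine can be imported directly from lampe.nn (it is an internal helper, not listed in __all__):

import torch
from lampe.nn import Affine  # assumed import path

standardize = Affine(shift=torch.zeros(3), scale=torch.ones(3), trainable=False)

print(list(standardize.parameters()))                      # [] -> nothing to optimize
print([name for name, _ in standardize.named_buffers()])   # ['shift', 'scale']

x = torch.randn(8, 3)
y = standardize(x)  # computes x * scale + shift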
@@ -56,15 +59,15 @@ def forward(self, x: Tensor) -> Tensor:
class ResMLP(nn.Sequential):
r"""Creates a residual multi-layer perceptron (ResMLP).

A ResMLP is a series of residual blocks where each block is a (shallow) MLP.
Using residual blocks instead of regular non-linear functions prevents the gradients
from vanishing, which allows for deeper networks.
A ResMLP is a series of residual blocks where each block is a (shallow) MLP. Using
residual blocks instead of regular non-linear functions prevents the gradients from
vanishing, which allows for deeper networks.

Arguments:
in_features: The number of input features.
out_features: The number of output features.
hidden_features: The numbers of hidden features.
kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`.
kwargs: Keyword arguments passed to :class:`MLP`.

Example:
>>> net = ResMLP(64, 1, [32, 16], activation=nn.ELU)
@@ -75,16 +78,12 @@ class ResMLP(nn.Sequential):
(0): Linear(in_features=32, out_features=32, bias=True)
(1): ELU(alpha=1.0)
(2): Linear(in_features=32, out_features=32, bias=True)
(3): ELU(alpha=1.0)
(4): Linear(in_features=32, out_features=32, bias=True)
))
(2): Linear(in_features=32, out_features=16, bias=True)
(3): Residual(MLP(
(0): Linear(in_features=16, out_features=16, bias=True)
(1): ELU(alpha=1.0)
(2): Linear(in_features=16, out_features=16, bias=True)
(3): ELU(alpha=1.0)
(4): Linear(in_features=16, out_features=16, bias=True)
))
(4): Linear(in_features=16, out_features=1, bias=True)
)
@@ -106,7 +105,7 @@ def __init__(
            if after != before:
                blocks.append(nn.Linear(before, after))

            blocks.append(Residual(MLP(after, after, [after] * 2, **kwargs)))
            blocks.append(Residual(MLP(after, after, [after], **kwargs)))

        blocks = blocks[:-1]

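For context, the Residual wrapper used in these blocks is not shown in this diff. A minimal sketch under the standard residual formulation y = x + f(x); lampe's actual implementation may differ in detail:

import torch
import torch.nn as nn

class Residual(nn.Module):
    r"""Wraps a module f so that the block computes x + f(x)."""

    def __init__(self, f: nn.Module):
        super().__init__()
        self.f = f

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.f(x)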
2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,4 +3,4 @@ matplotlib>=3.4.0
numpy>=1.20.0
torch>=1.8.0
tqdm>=4.52.0
zuko>=0.0.6
zuko>=0.0.8
5 changes: 3 additions & 2 deletions sphinx/tutorials.rst
@@ -6,5 +6,6 @@ Tutorials
1. Simulators and datasets <https://github.com/francois-rozet/lampe/blob/docs/tutorials/01_simulators.ipynb>
2. Neural posterior estimation <https://github.com/francois-rozet/lampe/blob/docs/tutorials/02_npe.ipynb>
3. Neural ratio estimation <https://github.com/francois-rozet/lampe/blob/docs/tutorials/03_nre.ipynb>
4. Expected coverage <https://github.com/francois-rozet/lampe/blob/docs/tutorials/04_coverage.ipynb>
5. Embedding and GPU <https://github.com/francois-rozet/lampe/blob/docs/tutorials/05_embedding.ipynb>
4. Neural score estimation <https://github.com/francois-rozet/lampe/blob/docs/tutorials/04_nse.ipynb>
5. Expected coverage <https://github.com/francois-rozet/lampe/blob/docs/tutorials/05_coverage.ipynb>
6. Embedding and GPU <https://github.com/francois-rozet/lampe/blob/docs/tutorials/06_embedding.ipynb>
47 changes: 47 additions & 0 deletions tests/test_inference.py
@@ -186,6 +186,53 @@ def test_AMNPELoss():
    assert l.requires_grad


def test_NSE():
    estimator = NSE(3, 5)

    # Non-batched
    theta, x, t = randn(3), randn(5), torch.tensor(0.5)
    score = estimator(theta, x, t)

    assert score.shape == (3,)
    assert score.requires_grad

    # Batched
    theta, x, t = randn(256, 3), randn(256, 5), randn(256)
    score = estimator(theta, x, t)

    assert score.shape == (256, 3)

    # Mixed
    theta, x, t = randn(256, 3), randn(5), randn(1)
    score = estimator(theta, x, t)

    assert score.shape == (256, 3)

    # Sample
    x = randn(32, 5)
    theta = estimator.flow(x).sample((8,))

    assert theta.shape == (8, 32, 3)

    # Log-density
    with torch.no_grad():
        log_p = estimator.flow(x).log_prob(theta)

    assert log_p.shape == (8, 32)


def test_NSELoss():
    estimator = NSE(3, 5)
    loss = NSELoss(estimator)

    theta, x = randn(256, 3), randn(256, 5)

    l = loss(theta, x)

    assert l.shape == ()
    assert l.requires_grad


def test_MetropolisHastings():
    log_f = lambda x: -(x**2).sum(dim=-1) / 2
    f = lambda x: torch.exp(log_f(x))
2 changes: 1 addition & 1 deletion tutorials/02_npe.ipynb
@@ -219,7 +219,7 @@
"\n",
"estimator.train()\n",
"\n",
"with tqdm(range(64), unit='epoch') as tq:\n",
"with tqdm(range(64), unit='epoch', ncols=88) as tq:\n",
" for epoch in tq:\n",
" losses = torch.stack([\n",
" step(loss(theta, x))\n",
4 changes: 2 additions & 2 deletions tutorials/03_nre.ipynb
@@ -155,7 +155,7 @@
"\n",
"estimator.train()\n",
"\n",
"with tqdm(range(128), unit='epoch') as tq:\n",
"with tqdm(range(128), unit='epoch', ncols=88) as tq:\n",
" for epoch in tq:\n",
" losses = torch.stack([\n",
" step(loss(theta, x))\n",
@@ -186,7 +186,7 @@
"estimator.eval()\n",
"\n",
"with torch.no_grad():\n",
" theta_0 = prior.sample((4096,)) # 1024 concurrent Markov chains\n",
" theta_0 = prior.sample((1024,)) # 1024 concurrent Markov chains\n",
" log_p = lambda theta: estimator(theta, x_star) + prior.log_prob(theta) # p(theta | x) = r(theta, x) p(theta)\n",
"\n",
" sampler = MetropolisHastings(theta_0, log_f=log_p, sigma=0.5)\n",
278 changes: 278 additions & 0 deletions tutorials/04_nse.ipynb

Large diffs are not rendered by default.

File renamed without changes.
File renamed without changes.