Improvements to FrEIA that came about from the IB-INN paper #52

Merged · 2 commits · Feb 2, 2021
1 change: 0 additions & 1 deletion FrEIA/__init__.py
@@ -3,6 +3,5 @@
structure of operations.'''
from . import framework
from . import modules
from . import dummy_modules

__all__ = ["framework", "modules"]
16 changes: 16 additions & 0 deletions FrEIA/framework/__init__.py
@@ -0,0 +1,16 @@
'''The framework module contains the logic used in building the graph and
inferring the order that the nodes have to be executed in forward and backward
direction.'''

from .reversible_graph_net import *
from .reversible_sequential_net import *

__all__ = [
'ReversibleSequential',
'ReversibleGraphNet',
'Node',
'InputNode',
'ConditionNode',
'OutputNode'
]

File renamed without changes.
24 changes: 0 additions & 24 deletions FrEIA/framework.py → FrEIA/framework/reversible_graph_net.py
@@ -1,7 +1,3 @@
'''The framework module contains the logic used in building the graph and
inferring the order that the nodes have to be executed in forward and backward
direction.'''

import sys
import warnings
import numpy as np
@@ -499,23 +495,3 @@ def get_module_by_name(self, name):
return node.module
except:
return None



# Testing example
if __name__ == '__main__':
inp = InputNode(4, 64, 64, name='input')
t1 = Node([(inp, 0)], dummys.dummy_mux, {}, name='t1')
s1 = Node([(t1, 0)], dummys.dummy_2split, {}, name='s1')

t2 = Node([(s1, 0)], dummys.dummy_module, {}, name='t2')
s2 = Node([(s1, 1)], dummys.dummy_2split, {}, name='s2')
t3 = Node([(s2, 0)], dummys.dummy_module, {}, name='t3')

m1 = Node([(t3, 0), (s2, 1)], dummys.dummy_2merge, {}, name='m1')
m2 = Node([(t2, 0), (m1, 0)], dummys.dummy_2merge, {}, name='m2')
outp = OutputNode([(m2, 0)], name='output')

all_nodes = [inp, outp, t1, s1, t2, s2, t3, m1, m2]

net = ReversibleGraphNet(all_nodes, 0, 1)
76 changes: 76 additions & 0 deletions FrEIA/framework/reversible_sequential_net.py
@@ -0,0 +1,76 @@
import torch.nn as nn
import torch

class ReversibleSequential(nn.Module):
'''Simpler than FrEIA.framework.ReversibleGraphNet:
Only supports a sequential series of modules (no splitting, merging, branching off).
Has an append() method to add new blocks in a simpler way than the computation-graph
based approach of ReversibleGraphNet. For example:

inn = ReversibleSequential(channels, dims_H, dims_W)

for i in range(n_blocks):
inn.append(FrEIA.modules.AllInOneBlock, clamp=2.0, permute_soft=True)
inn.append(FrEIA.modules.HaarDownsampling)
# and so on

'''

def __init__(self, *dims):
super().__init__()

self.shapes = [tuple(dims)]
self.conditions = []
self.module_list = nn.ModuleList()

def append(self, module_class, cond=None, cond_shape=None, **kwargs):
'''Append a reversible block from FrEIA.modules to the network.
module_class: Class from FrEIA.modules.
cond (int): index of which condition to use (conditions will be passed as list to forward()).
Conditioning nodes are not needed for ReversibleSequential.
cond_shape (tuple[int]): the shape of the condition tensor.
**kwargs: Further keyword arguments that are passed to the constructor of module_class (see example).
'''

dims_in = [self.shapes[-1]]
self.conditions.append(cond)

if cond is not None:
kwargs['dims_c'] = [cond_shape]

module = module_class(dims_in, **kwargs)
self.module_list.append(module)
output_dims = module.output_dims(dims_in)
assert len(output_dims) == 1, "Module has more than one output"
self.shapes.append(output_dims[0])


def forward(self, x, c=None, rev=False):
'''
x (Tensor): input tensor (in contrast to ReversibleGraphNet, a list of tensors is not
supported, as ReversibleSequential only has one input).
c (list[Tensor]): list of conditions.
rev: whether to compute the network forward or reversed.

Returns
z (Tensor): network output.
jac (Tensor): log-jacobian-determinant.
There is no separate log_jacobian() method; it is computed automatically during forward().
'''

iterator = range(len(self.module_list))
jac = 0

if rev:
iterator = reversed(iterator)

for i in iterator:
if self.conditions[i] is None:
x, j = (self.module_list[i]([x], rev=rev)[0],
self.module_list[i].jacobian(x, rev=rev))
else:
x, j = (self.module_list[i]([x], c=[c[self.conditions[i]]], rev=rev)[0],
self.module_list[i].jacobian(x, c=[c[self.conditions[i]]], rev=rev))
jac = j + jac

return x, jac
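
For reference, a minimal usage sketch of the new ReversibleSequential API described in the docstrings above. The fully-connected subnet helper fc_subnet, the chosen dimensions, and the block settings are illustrative assumptions, not part of this diff:

import torch
import torch.nn as nn
import FrEIA.framework as Ff
import FrEIA.modules as Fm

def fc_subnet(channels_in, channels_out):
    # called by the coupling block as f(channels_in, channels_out)
    return nn.Sequential(nn.Linear(channels_in, 128), nn.ReLU(),
                         nn.Linear(128, channels_out))

inn = Ff.ReversibleSequential(8)               # 8 input channels, no spatial dims
for _ in range(4):
    inn.append(Fm.AllInOneBlock, cond=0, cond_shape=(10,),
               subnet_constructor=fc_subnet, permute_soft=True)

x = torch.randn(16, 8)                         # batch of 16 samples
c = torch.randn(16, 10)                        # condition tensor with index 0
z, log_jac_det = inn(x, c=[c])                 # forward pass
x_reconstructed, _ = inn(z, c=[c], rev=True)   # inverse pass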
4 changes: 3 additions & 1 deletion FrEIA/modules/__init__.py
@@ -5,6 +5,7 @@

Coupling blocks:

* AllInOneBlock
* NICECouplingBlock
* RNVPCouplingBlock
* GLOWCouplingBlock
@@ -32,7 +33,6 @@

Graph topology:


* SplitChannel
* ConcatChannel
* Split1D
@@ -49,6 +49,7 @@

'''

from .all_in_one_block import *
from .fixed_transforms import *
from .reshapes import *
from .coupling_layers import *
@@ -60,6 +61,7 @@
from .gaussian_mixture import *

__all__ = [
'AllInOneBlock',
'glow_coupling_layer',
'rev_layer',
'rev_multiplicative_layer',
212 changes: 212 additions & 0 deletions FrEIA/modules/all_in_one_block.py
@@ -0,0 +1,212 @@
import pdb
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import special_ortho_group

class AllInOneBlock(nn.Module):
''' Combines affine coupling, permutation, global affine transformation ('ActNorm')
in one block.'''

def __init__(self, dims_in, dims_c=[],
subnet_constructor=None,
affine_clamping=2.,
gin_block=False,
global_affine_init=1.,
global_affine_type='SOFTPLUS',
permute_soft=False,
learned_householder_permutation=0,
reverse_permutation=False):
'''
subnet_constructor: class or callable f, called as
f(channels_in, channels_out) and should return a torch.nn.Module

affine_clamping: clamp the output of the multiplicative coefficients
(before exponentiation) to +/- affine_clamping.

gin_block: Turn the block into a GIN block from Sorrenson et al, 2019

global_affine_init: Initial value for the global affine scaling beta

global_affine_type: 'SIGMOID', 'SOFTPLUS', or 'EXP'. Defines the activation
to be used on the beta for the global affine scaling.

permute_soft: bool, whether to sample the permutation matrices from SO(N),
or to use hard permutations instead. Note: permute_soft=True is very slow
when working with >512 dimensions.

learned_householder_permutation: Int, if >0, use that many learned householder
reflections. Slow for a large number of reflections; dubious whether it actually helps.

reverse_permutation: Reverse the permutation before the block, as introduced by
Putzky et al, 2019.
'''

super().__init__()

channels = dims_in[0][0]
self.Ddim = len(dims_in[0]) - 1
self.sum_dims = tuple(range(1, 2 + self.Ddim))

if len(dims_c) == 0:
self.conditional = False
self.condition_channels = 0
elif len(dims_c) == 1:
self.conditional = True
self.condition_channels = dims_c[0][0]
assert tuple(dims_c[0][1:]) == tuple(dims_in[0][1:]), \
F"Dimensions of input and condition don't agree: {dims_c} vs {dims_in}."
else:
raise ValueError('Only supports one condition (concatenate externally)')

split_len1 = channels - channels // 2
split_len2 = channels // 2
self.splits = [split_len1, split_len2]


try:
self.permute_function = {0 : F.linear,
1 : F.conv1d,
2 : F.conv2d,
3 : F.conv3d}[self.Ddim]
except KeyError:
raise ValueError(f"Data has {1 + self.Ddim} dimensions. Must be 1-4.")

self.in_channels = channels
self.clamp = affine_clamping
self.GIN = gin_block
self.welling_perm = reverse_permutation
self.householder = learned_householder_permutation

if permute_soft and channels > 512:
warnings.warn(("Soft permutation will take a very long time to initialize "
f"with {channels} feature channels. Consider using hard permutation instead."))

if global_affine_type == 'SIGMOID':
global_scale = np.log(global_affine_init)
self.global_scale_activation = (lambda a: 10 * torch.sigmoid(a - 2.))
elif global_affine_type == 'SOFTPLUS':
global_scale = 10. * global_affine_init
self.softplus = nn.Softplus(beta=0.5)
self.global_scale_activation = (lambda a: 0.1 * self.softplus(a))
elif global_affine_type == 'EXP':
global_scale = np.log(global_affine_init)
self.global_scale_activation = (lambda a: torch.exp(a))
else:
raise ValueError('Global affine type must be SIGMOID, SOFTPLUS or EXP')

self.global_scale = nn.Parameter(torch.ones(1, self.in_channels, *([1] * self.Ddim)) * float(global_scale))
self.global_offset = nn.Parameter(torch.zeros(1, self.in_channels, *([1] * self.Ddim)))

if permute_soft:
w = special_ortho_group.rvs(channels)
else:
w = np.zeros((channels,channels))
for i,j in enumerate(np.random.permutation(channels)):
w[i,j] = 1.

if self.householder:
self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True)
self.w = None
self.w_inv = None
self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False)
else:
self.w = nn.Parameter(torch.FloatTensor(w).view(channels, channels, *([1] * self.Ddim)),
requires_grad=False)
self.w_inv = nn.Parameter(torch.FloatTensor(w.T).view(channels, channels, *([1] * self.Ddim)),
requires_grad=False)

self.s = subnet_constructor(self.splits[0] + self.condition_channels, 2 * self.splits[1])
self.last_jac = None

def construct_householder_permutation(self):
w = self.w_0
for vk in self.vk_householder:
w = torch.mm(w, torch.eye(self.in_channels).to(w.device) - 2 * torch.ger(vk, vk) / torch.dot(vk, vk))

for i in range(self.Ddim):
w = w.unsqueeze(-1)
return w

def log_e(self, s):
s = self.clamp * torch.tanh(0.1 * s)
if self.GIN:
s -= torch.mean(s, dim=self.sum_dims, keepdim=True)
return s

def permute(self, x, rev=False):
if self.GIN:
scale = 1.
else:
scale = self.global_scale_activation(self.global_scale)
if rev:
return (self.permute_function(x, self.w_inv) - self.global_offset) / scale
else:
return self.permute_function(x * scale + self.global_offset, self.w)

def pre_permute(self, x, rev=False):
if rev:
return self.permute_function(x, self.w)
else:
return self.permute_function(x, self.w_inv)

def affine(self, x, a, rev=False):
ch = x.shape[1]
sub_jac = self.log_e(a[:,:ch])
if not rev:
return (x * torch.exp(sub_jac) + 0.1 * a[:,ch:],
torch.sum(sub_jac, dim=self.sum_dims))
else:
return ((x - 0.1 * a[:,ch:]) * torch.exp(-sub_jac),
-torch.sum(sub_jac, dim=self.sum_dims))

def forward(self, x, c=[], rev=False):
if self.householder:
self.w = self.construct_householder_permutation()
if rev or self.welling_perm:
self.w_inv = self.w.transpose(0,1).contiguous()

if rev:
x = [self.permute(x[0], rev=True)]
elif self.welling_perm:
x = [self.pre_permute(x[0], rev=False)]

x1, x2 = torch.split(x[0], self.splits, dim=1)

if self.conditional:
x1c = torch.cat([x1, *c], 1)
else:
x1c = x1

if not rev:
a1 = self.s(x1c)
x2, j2 = self.affine(x2, a1)
else:
# names of x and y are swapped!
a1 = self.s(x1c)
x2, j2 = self.affine(x2, a1, rev=True)

self.last_jac = j2
x_out = torch.cat((x1, x2), 1)

n_pixels = 1
for d in self.sum_dims[1:]:
n_pixels *= x_out.shape[d]

self.last_jac += ((-1)**rev * n_pixels) * (1 - int(self.GIN)) * (torch.log(self.global_scale_activation(self.global_scale) + 1e-12).sum())

if not rev:
x_out = self.permute(x_out, rev=False)
elif self.welling_perm:
x_out = self.pre_permute(x_out, rev=True)

return [x_out]

def jacobian(self, x, c=[], rev=False):
return self.last_jac

def output_dims(self, input_dims):
return input_dims
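
For reference, a standalone sketch of calling the new AllInOneBlock through the list-in, list-out module interface shown above, on 2D (image-shaped) data. The convolutional subnet conv_subnet and the tensor shapes are illustrative assumptions, not part of this diff:

import torch
import torch.nn as nn
from FrEIA.modules import AllInOneBlock

def conv_subnet(channels_in, channels_out):
    # called by the block as f(channels_in, channels_out)
    return nn.Sequential(nn.Conv2d(channels_in, 64, 3, padding=1), nn.ReLU(),
                         nn.Conv2d(64, channels_out, 3, padding=1))

# 16 channels on an 8x8 grid; permute_soft samples the permutation from SO(16)
block = AllInOneBlock([(16, 8, 8)], subnet_constructor=conv_subnet,
                      affine_clamping=2.0, permute_soft=True)

x = torch.randn(4, 16, 8, 8)
y, = block([x])                   # forward: list of tensors in, list of tensors out
log_jac = block.jacobian([x])     # log-det, already computed during forward()
x_rec, = block([y], rev=True)     # inverse pass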