From cf4e8850cffb636dedb253751ceaaf0ce19860af Mon Sep 17 00:00:00 2001 From: Lynton Ardizzone Date: Tue, 2 Feb 2021 15:47:37 +0100 Subject: [PATCH] added ReversibleSequential net (simpler than RevGraphNet for sequential architectures). See docstring. --- FrEIA/__init__.py | 1 - FrEIA/framework/__init__.py | 16 ++++ FrEIA/{ => framework}/dummy_modules.py | 0 .../reversible_graph_net.py} | 24 ------ FrEIA/framework/reversible_sequential_net.py | 76 +++++++++++++++++++ 5 files changed, 92 insertions(+), 25 deletions(-) create mode 100644 FrEIA/framework/__init__.py rename FrEIA/{ => framework}/dummy_modules.py (100%) rename FrEIA/{framework.py => framework/reversible_graph_net.py} (95%) create mode 100644 FrEIA/framework/reversible_sequential_net.py diff --git a/FrEIA/__init__.py b/FrEIA/__init__.py index f0b2a99..372c81c 100644 --- a/FrEIA/__init__.py +++ b/FrEIA/__init__.py @@ -3,6 +3,5 @@ structure of operations.''' from . import framework from . import modules -from . import dummy_modules __all__ = ["framework", "modules"] diff --git a/FrEIA/framework/__init__.py b/FrEIA/framework/__init__.py new file mode 100644 index 0000000..5464feb --- /dev/null +++ b/FrEIA/framework/__init__.py @@ -0,0 +1,16 @@ +'''The framework module contains the logic used in building the graph and +inferring the order that the nodes have to be executed in forward and backward +direction.''' + +from .reversible_graph_net import * +from .reversible_sequential_net import * + +__all__ = [ + 'ReversibleSequential', + 'ReversibleGraphNet', + 'Node', + 'InputNode', + 'ConditionNode', + 'OutputNode' + ] + diff --git a/FrEIA/dummy_modules.py b/FrEIA/framework/dummy_modules.py similarity index 100% rename from FrEIA/dummy_modules.py rename to FrEIA/framework/dummy_modules.py diff --git a/FrEIA/framework.py b/FrEIA/framework/reversible_graph_net.py similarity index 95% rename from FrEIA/framework.py rename to FrEIA/framework/reversible_graph_net.py index 5bf9b04..2539153 100644 --- a/FrEIA/framework.py +++ b/FrEIA/framework/reversible_graph_net.py @@ -1,7 +1,3 @@ -'''The framework module contains the logic used in building the graph and -inferring the order that the nodes have to be executed in forward and backward -direction.''' - import sys import warnings import numpy as np @@ -499,23 +495,3 @@ def get_module_by_name(self, name): return node.module except: return None - - - -# Testing example -if __name__ == '__main__': - inp = InputNode(4, 64, 64, name='input') - t1 = Node([(inp, 0)], dummys.dummy_mux, {}, name='t1') - s1 = Node([(t1, 0)], dummys.dummy_2split, {}, name='s1') - - t2 = Node([(s1, 0)], dummys.dummy_module, {}, name='t2') - s2 = Node([(s1, 1)], dummys.dummy_2split, {}, name='s2') - t3 = Node([(s2, 0)], dummys.dummy_module, {}, name='t3') - - m1 = Node([(t3, 0), (s2, 1)], dummys.dummy_2merge, {}, name='m1') - m2 = Node([(t2, 0), (m1, 0)], dummys.dummy_2merge, {}, name='m2') - outp = OutputNode([(m2, 0)], name='output') - - all_nodes = [inp, outp, t1, s1, t2, s2, t3, m1, m2] - - net = ReversibleGraphNet(all_nodes, 0, 1) diff --git a/FrEIA/framework/reversible_sequential_net.py b/FrEIA/framework/reversible_sequential_net.py new file mode 100644 index 0000000..38ccb33 --- /dev/null +++ b/FrEIA/framework/reversible_sequential_net.py @@ -0,0 +1,76 @@ +import torch.nn as nn +import torch + +class ReversibleSequential(nn.Module): + '''Simpler than FrEIA.framework.ReversibleGraphNet: + Only supports a sequential series of modules (no splitting, merging, branching off). + Has an append() method, to add new blocks in a more simple way than the computation-graph + based approach of ReversibleGraphNet. For example: + + inn = ReversibleSequential(channels, dims_H, dims_W) + + for i in range(n_blocks): + inn.append(FrEIA.modules.AllInOneBlock, clamp=2.0, permute_soft=True) + inn.append(FrEIA.modules.HaarDownsampling) + # and so on + + ''' + + def __init__(self, *dims): + super().__init__() + + self.shapes = [tuple(dims)] + self.conditions = [] + self.module_list = nn.ModuleList() + + def append(self, module_class, cond=None, cond_shape=None, **kwargs): + '''Append a reversible block from FrEIA.modules to the network. + module_class: Class from FrEIA.modules. + cond (int): index of which condition to use (conditions will be passed as list to forward()). + Conditioning nodes are not needed for ReversibleSequential. + cond_shape (tuple[int]): the shape of the condition tensor. + **kwargs: Further keyword arguments that are passed to the constructor of module_class (see example). + ''' + + dims_in = [self.shapes[-1]] + self.conditions.append(cond) + + if cond is not None: + kwargs['dims_c'] = [cond_shape] + + module = module_class(dims_in, **kwargs) + self.module_list.append(module) + ouput_dims = module.output_dims(dims_in) + assert len(ouput_dims) == 1, "Module has more than one output" + self.shapes.append(ouput_dims[0]) + + + def forward(self, x, c=None, rev=False): + ''' + x (Tensor): input tensor (in contrast to ReversibleGraphNet, a list of tensors is not + supported, as ReversibleSequential only has one input). + c (list[Tensor]): list of conditions. + rev: whether to compute the network forward or reversed. + + Returns + z (Tensor): network output. + jac (Tensor): log-jacobian-determinant. + There is no separate log_jacobian() method, it is automatically computed during forward(). + ''' + + iterator = range(len(self.module_list)) + jac = 0 + + if rev: + iterator = reversed(iterator) + + for i in iterator: + if self.conditions[i] is None: + x, j = (self.module_list[i]([x], rev=rev)[0], + self.module_list[i].jacobian(x, rev=rev)) + else: + x, j = (self.module_list[i]([x], c=[c[self.conditions[i]]], rev=rev)[0], + self.module_list[i].jacobian(x, c=[c[self.conditions[i]]], rev=rev)) + jac = j + jac + + return x, jac