MONAI Network Design Discussion
Networks are a place in the design where a number of design decisions converge. They represent a choke point where any issues in the design of layers and layer factories will manifest when trying to identify reusable elements for networks.
Network functionality represents a major design opportunity for MONAI. Pytorch is very much unopinionated in how networks are defined. It provides Module as a base class from which to create a network, and a few methods that must be implemented, but there is no prescribed pattern nor much helper functionality for initialising networks. This leaves a lot of room for defining useful 'best practice' patterns for constructing new networks in MONAI. Although trivial, inflexible network implementations are easy enough to write, we can give users a toolset that makes it much easier to build well-engineered, flexible networks, and demonstrate its value by committing to use it in the networks that we build.
- layers (covered here)
- layer factories (covered here)
- network definition
  - compatibility with torchscript
  - reference and configurable versions
  - configurability
    - modules
    - blocks
    - structure
      - width
      - depth
    - rules-based configuration
      - modules
      - network structure
  - dimension-agnosticism
  - recursion vs iteration
  - initialisation / model restore
- network support
  - hyper-parameter search techniques
  - higher-order training techniques
    - adaptive loss schemes
    - dynamic curriculum learning techniques
- MVP
  - provision of layers
  - provision of layer factories (if agreed upon)
  - provision of reference network implementations
  - provision of configurable network implementations sufficient for MVP
    - configuration of modular blocks
    - configuration of structure (layer counts etc.)
  - provision of network utilities
- post-MVP
  - higher-order training techniques
  - refinement of configurable network implementations
  - model size calculation
- longer term
  - higher-order training techniques
  - hyper-parameter search techniques
This topic is covered in detail here and here.
Compatibility with torchscript is a key capability. All mechanisms that we create for constructing networks must be torchscript compatible, and this imposes restrictions on how such mechanisms are implemented. This can impact models using features such as:
- tied weights
- densely-connected network blocks (Eric please elaborate)
We should do a full survey of torchscript-related issues that must be avoided in our network functionality.
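As a minimal sketch (the block and names here are illustrative, not MONAI components), running torch.jit.script over each module in unit tests is one way to surface torchscript incompatibilities early:

```python
import torch
import torch.nn as nn

class SimpleBlock(nn.Module):
    """Minimal illustrative block used only to show the scripting check."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.PReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.conv(x))

# torch.jit.script fails loudly on unsupported constructs, so scripting every
# network and block in unit tests catches incompatibilities before deployment
scripted = torch.jit.script(SimpleBlock(4))
print(scripted(torch.randn(1, 4, 8, 8)).shape)
```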
Every network type that comes from a paper should have a plain, 'unconfigurable' implementation whose purpose is to allow people to replicate results and to be easy to understand. Such networks can also be used as regression test sources for more configurable network implementations.
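A hedged sketch of what such a regression check might look like; the helper below and the example networks named in the comment are hypothetical, not existing MONAI code:

```python
import torch

def assert_structurally_equivalent(reference, configurable, example_input):
    """Regression-check sketch: the configurable network, set up with the
    equivalent configuration, should match the plain reference implementation
    in parameter count and output shape."""
    ref_params = sum(p.numel() for p in reference.parameters())
    cfg_params = sum(p.numel() for p in configurable.parameters())
    assert ref_params == cfg_params, (ref_params, cfg_params)
    assert reference(example_input).shape == configurable(example_input).shape

# hypothetical usage, e.g.:
# assert_structurally_equivalent(ReferenceUNet(), ConfigurableUNet(...), torch.randn(1, 1, 64, 64))
```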
Outside of reference networks, which are implemented in a way that serves their particular purpose, our network implementations should have strong configurability as a primary goal. We should set standards of configurability that all such network implementations should meet, each of which is covered in a subsection.
Layers need to be attributes of a module for pytorch to be able to recognise them. This affects all aspects of configurability, as both the number of layers and the number of downsamples / upsamples require extra hoops to be jumped through beyond merely adding layers to a Python list. Pytorch provides (at least) the following mechanisms:
- `Module.__setattr__`
- `Module.add_module`
- `self.layers = nn.ModuleList()`
- `nn.Sequential`
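For illustration, a minimal sketch contrasting a plain Python list (which pytorch does not register) with nn.ModuleList / add_module:

```python
import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # a plain Python list: the linear layers are NOT registered, so
        # .parameters() is empty and the layers are missing from state_dict()
        self.layers = [nn.Linear(4, 4) for _ in range(3)]

class RegisteredList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each entry as a submodule
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))
        # add_module achieves the same thing with an explicit name
        self.add_module("head", nn.Linear(4, 2))

print(len(list(PlainList().parameters())))       # 0
print(len(list(RegisteredList().parameters())))  # 8: three weight/bias pairs plus the head
```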
Other patterns also exist, such as modular recursion (TODO: Eric already has examples like this in the codebase; reference them here):
```python
import torch.nn as nn

class RecursiveModule(nn.Module):
    """Chains a list of modules by storing each subsequent link as an attribute
    of the previous one, so every level is registered as a submodule."""

    def __init__(self, entries):
        super().__init__()
        self.cur = entries[0]
        self.next = RecursiveModule(entries[1:]) if len(entries) > 1 else None

    def forward(self, t_input):
        if self.next is not None:
            return self.next(self.cur(t_input))
        return self.cur(t_input)
```
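As a usage sketch, assuming the class above:

```python
# chains three layers; each link is registered through attribute assignment
chain = RecursiveModule([nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)])
print(len(list(chain.parameters())))  # 4: two weights and two biases
```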
Many networks have a natural modularity to them in terms of what is considered a 'unit' of computation. The variants of ResNet block are a good example of this. Such blocks should be replaceable modular elements in an overall network structure, as innovations on a base network architecture often vary in the nature of the blocks. Allowing modules to be a configurable element of a network design gives developers the means to experiment readily with innovations on existing networks.
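A minimal sketch of this idea (class names are illustrative, not proposed MONAI APIs): the block type is passed in as a constructor argument so that variants can be swapped without rewriting the network.

```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Illustrative 'unit' of computation; any module keeping the channel
    count unchanged could stand in for it."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
        )

    def forward(self, x):
        return x + self.conv(x)

class BlockConfigurableNet(nn.Module):
    """The block type is a constructor argument, so pre-activation blocks,
    bottlenecks, etc. can be swapped in without rewriting the network."""

    def __init__(self, channels, num_blocks, block=ResidualBlock):
        super().__init__()
        self.features = nn.Sequential(*[block(channels) for _ in range(num_blocks)])

    def forward(self, x):
        return self.features(x)

# swapping the modular element is a constructor argument away
net = BlockConfigurableNet(channels=16, num_blocks=4, block=ResidualBlock)
```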
Structural configurability can be separated into two subtypes:
- Configuration of layer counts
- Recursive network architectures
Structural configurability is the area most impacted by pytorch's need to have layers be attributes of nn.Module instances. Layer counts are affected by this, and any of the solutions mentioned above is a potential candidate for solving the problem.
Recursive network architectures are slightly more complex. Where layer counts represent a 'horizontal' configurability (aka network depth), downsample counts represent an example of 'vertical' configurability.
A Unet can be thought of as a recursive structure where each level of recursion is a given resolution.
Each resolution (achieved by downsampling) is a series of three concentric layers:
- A convolutional block (CB) that contains the convolutional Modules
- A convolutional layer (CL) that wraps the convolutional block with downsampling and upsampling modules
- A skipped layer (SL) that has two paths: one through a skip module (nn.Identity or some other module that does work on the skip connection) and one through the convolutional layer that does the work, with the two results concatenated
Eric's UNet implementation, which we have used as a baseline configurable UNet, does layering along these lines by calling recursive construction functions.
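The following is not Eric's actual implementation, only a minimal sketch of the recursive-construction idea with illustrative names (conv_block, SkipCat, build_unet_level) and placeholder down/up-sampling choices: each recursive call builds one resolution level, wrapping the level below with downsampling/upsampling and a skip connection.

```python
import torch
import torch.nn as nn

def conv_block(channels_in, channels_out):
    # stand-in for the convolutional block (CB); the real block would be configurable
    return nn.Sequential(nn.Conv2d(channels_in, channels_out, 3, padding=1), nn.ReLU(inplace=True))

class SkipCat(nn.Module):
    """Skipped layer (SL): concatenates the skip path with the wrapped inner path."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return torch.cat([x, self.inner(x)], dim=1)

def build_unet_level(channels):
    """Builds one resolution level per recursive call: the deepest call returns a
    bare bottom block; each outer call wraps the level below with downsampling and
    upsampling (CL) and a skip connection (SL)."""
    c, rest = channels[0], channels[1:]
    if not rest:
        return conv_block(c, c)  # bottom of the 'U'
    inner = build_unet_level(rest)
    down_inner_up = nn.Sequential(
        nn.Conv2d(c, rest[0], 3, stride=2, padding=1),  # downsample into the level below
        inner,
        nn.ConvTranspose2d(rest[0], c, 2, stride=2),    # upsample back to this level
    )
    # concatenate skip and processed paths, then merge channels back down
    return nn.Sequential(SkipCat(down_inner_up), conv_block(2 * c, c))

net = build_unet_level([8, 16, 32])
print(net(torch.randn(1, 8, 64, 64)).shape)  # torch.Size([1, 8, 64, 64])
```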
Ben has been experimenting with ways of achieving a similar design but through an iterative approach. This is not part of any PR at this point, but looks like this:
```python
from functools import partial

import torch.nn as nn

# DownSample, UpSample and ResNetBlock are assumed to be defined elsewhere
# (e.g. produced by the layer factories discussed above).

class SkippedLayer(nn.Module):
    """Skipped layer (SL): combines the skip path with the wrapped inner path."""

    def __init__(self, inner):
        super(SkippedLayer, self).__init__()
        self.skip = nn.Identity()
        self.inner = inner

    def forward(self, t_input):
        return self.skip(t_input) + self.inner(t_input)


class ConvolutionalLayer(nn.Module):
    """Convolutional layer (CL): wraps the inner module with down/up-sampling."""

    def __init__(self, inner):
        super(ConvolutionalLayer, self).__init__()
        self.model = nn.Sequential(DownSample(), inner, UpSample())

    def forward(self, t_input):
        return self.model(t_input)


class ConvolutionalBlock(nn.Module):
    """Convolutional block (CB): encoder blocks, optional inner module, decoder blocks."""

    # note: inner precedes the count arguments so the wiring loop below can pass it positionally
    def __init__(self, block_fn, inner=None, enc_count=1, dec_count=1):
        super(ConvolutionalBlock, self).__init__()
        encoders = [block_fn() for _ in range(enc_count)]
        decoders = [block_fn() for _ in range(dec_count)]
        if inner is not None:
            self.model = nn.Sequential(*encoders, inner, *decoders)
        else:
            self.model = nn.Sequential(*encoders, *decoders)

    def forward(self, t_input):
        return self.model(t_input)

    @staticmethod
    def factory(*args, **kwargs):
        return ConvolutionalBlock(*args, **kwargs)


class UnetFramework(nn.Module):
    def __init__(self, initial_fn, final_fn):  # initial_fn / final_fn not yet wired in
        super(UnetFramework, self).__init__()
        encoder_counts = [1, 2, 2, 4]
        decoder_counts = [1, 1, 1, 2]

        layers = []
        for i in range(len(encoder_counts)):
            layers.append(SkippedLayer)
            layers.append(ConvolutionalLayer)
            layers.append(partial(ConvolutionalBlock.factory, ResNetBlock,
                                  enc_count=encoder_counts[i], dec_count=decoder_counts[i]))

        # TODO: refactor out into a function that wires up the blocks recursively
        # from the specified array
        inner = None
        for layer in reversed(layers):
            inner = layer(inner)
        self.model = inner

    def forward(self, t_input):
        return self.model(t_input)
```
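As a usage sketch only: DownSample, UpSample and the block type must come from somewhere, so the stand-ins below are hypothetical placeholders (not MONAI components) that let the wiring above be exercised end to end.

```python
import torch

# hypothetical stand-ins so the framework sketch can be exercised; real
# implementations would come from MONAI's layers / layer factories
DownSample = lambda: nn.MaxPool2d(2)
UpSample = lambda: nn.Upsample(scale_factor=2)
ResNetBlock = lambda: nn.Identity()

net = UnetFramework(initial_fn=None, final_fn=None)
print(net(torch.randn(1, 1, 16, 16)).shape)  # torch.Size([1, 1, 16, 16])
```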
TODO
TODO
Network support refers to everything around network design that supports sophisticated use of networks. This includes:
- hyperparameter exploration
- adaptive loss techniques, especially those involving network outputs
- adaptive sample selection techniques such as curriculum learning, especially those involving network outputs
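As an illustrative sketch of the last point (names and signature are hypothetical, not a proposed API), a curriculum-style selector might rank samples by the network's current per-sample loss:

```python
import torch

def select_curriculum_samples(network, loss_fn, samples, fraction=0.5):
    """Hypothetical sketch of output-driven sample selection: rank candidate
    (input, target) pairs by the network's current per-sample loss and keep
    the easiest fraction, growing the pool as training progresses."""
    with torch.no_grad():
        losses = torch.stack([loss_fn(network(x), y) for x, y in samples])
    keep = torch.argsort(losses)[: max(1, int(fraction * len(samples)))]
    return [samples[int(i)] for i in keep]
```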