Training refactoring #303

Merged 11 commits on Apr 24, 2024
43 changes: 34 additions & 9 deletions README.md
@@ -81,12 +81,16 @@ you can send a SLURM job to the cluster using the script in [tools](tools/slurm_
sbatch slurm_run.sh
```

### Quick start

## Running Affinity-VAE: A quick start

We have a [tutorial](tutorials/README.md) on how to run Affinity-VAE on the
MNIST dataset. We recommend starting there the first time you run
Affinity-VAE.

<details>
<summary><i>Affinity-VAE configuration parameters</i></summary>

Affinity-VAE comes with a run script (`run.py`) that allows you to configure
and run the code. You can list the available configuration options by running:

@@ -128,9 +132,22 @@ Options:
-de, --depth INTEGER Depth of the convolutional layers (default
3).
-ch, --channels INTEGER First layer channels (default 64).
-fl, --filters TEXT Comma-separated list of filters for the
network. Either provide filters, or capacity
and depth.
-ld, --latent_dims INTEGER Latent space dimension (default 10).
-pd, --pose_dims INTEGER If pose on, number of pose dimensions. If 0
and gamma=0 it becomes a standard beta-VAE.
-bn_enc, --bnorm_encoder Batch normalisation in encoder is on if
True.
-bn_dec, --bnorm_decoder Batch normalisation in decoder is on if
True.
-gsdcl, --gsd_conv_layers INTEGER
The number of output channels for the
convolution layers at the end of the GSD
decoder.
-spl, --n_splats INTEGER Number of Gaussian splats.
-kr, --klreduction TEXT Mean or sum reduction on KLD term.
-be, --beta FLOAT Beta maximum in the case of cyclical
annealing schedule.
-bl, --beta_load The path to the saved beta array file to be
@@ -165,7 +182,8 @@ Options:
-ev, --eval Evaluate test data.
-dn, --dynamic Enable collecting meta and dynamic latent
space plots.
-m, --model TEXT Choose model to run.
-m, --model TEXT Choose model to run. Available models are
a, b, u and gsd.
-vl, --vis_los Visualise loss (every epoch starting at
epoch 2).
-vac, --vis_acc Visualise confusion matrix and F1 scores
@@ -181,21 +199,25 @@ Options:
-vps, --vis_pos Visualise pose disentanglement (frequency
controlled).
-vpsc, --vis_pose_class TEXT Example: A,B,C. Your delimiter should be
commas and no spaces .Classes to be used for
commas and no spaces. Classes to be used for
pose interpolation (a separate pose
interpolation figure will be created for
each class).
-vpsc, --vis_z_n_int TEXT Number of latent interpolation classes to be printed, and number of interpolation steps in each plot.
Example: 1,10. 1 plot with 10 interpolation steps between two classes.
Your delimiter should be commas and no spaces.
-vzni, --vis_z_n_int TEXT Number of latent interpolation classes to
be printed, and number of interpolation
steps in each plot. Example: 1,10. 1 plot
with 10 interpolation steps between two
classes. Your delimiter should be commas
and no spaces.
-vc, --vis_cyc Visualise cyclical parameters (once per
run).
-va, --vis_aff Visualise affinity matrix (once per run).
-his, --vis_his Visualise train-val class distribution (once
per run).
-similarity, --vis_sim Visualise train-val model similarity matrix.
-va, --vis_all Visualise all above.
-vf, --vis_format The format of saved images. Options: png, pdf
-vf, --vis_format TEXT The format of saved images. Options: png,
pdf
-fev, --freq_eval INTEGER Frequency at which to evaluate test set.
-fs, --freq_sta INTEGER Frequency at which to save state.
-fac, --freq_acc INTEGER Frequency at which to visualise confusion
@@ -222,17 +244,20 @@ Options:
-nrm, --normalise Normalise data
-sftm, --shift_min Shift the minimum of the data to zero
and the maximum to one.
-res --rescale Rescale images to given value (tuple, one
-res, --rescale INTEGER Rescale images to given value (tuple, one
value per dim).
-tb, --tensorboard Log metrics and figures to tensorboard
during training
-st, --strategy TEXT Define the strategy for distributed
training. Options are: 'ddp', 'deepspeed' or
'fsdp'.
--help Show this message and exit.
```

Note that setting `-g/--gamma` to `0` and `-pd/--pose_dims` to `0` will run a
vanilla beta-VAE.
</details>

### Quickstart

#### Configuring from the command line

33 changes: 33 additions & 0 deletions avae/base.py
@@ -1,4 +1,5 @@
import enum
import logging

import torch

@@ -34,3 +35,35 @@ def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor):
raise NotImplementedError(
"Reparameterize method must be implemented in child class."
)


def set_layer_dim(
    ndim: SpatialDims | int,
) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Return the convolution, transposed-convolution and batch-norm
    classes matching the spatial dimensionality of the data."""
    if ndim == SpatialDims.TWO:
        return torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d
    elif ndim == SpatialDims.THREE:
        return torch.nn.Conv3d, torch.nn.ConvTranspose3d, torch.nn.BatchNorm3d
    else:
        logging.error("Data must be 2D or 3D.")
        exit(1)


def dims_after_pooling(start: int, n_pools: int) -> int:
"""Calculate the size of a layer after n pooling ops.

Parameters
----------
start: int
The size of the layer before pooling.
n_pools: int
The number of pooling operations.

    Returns
    -------
    int
        The size of the layer after pooling.
    """
    return start // (2**n_pools)
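
For orientation, here is a minimal usage sketch for the two helpers above. It
assumes `SpatialDims` is the enum defined earlier in `avae/base.py` and that
its members compare equal to the integers 2 and 3, as the decoders' calls to
`set_layer_dim(len(input_shape))` imply:

```python
from avae.base import SpatialDims, dims_after_pooling, set_layer_dim

# Pick the 2D layer classes (Conv2d, ConvTranspose2d, BatchNorm2d).
CONV, TCONV, BNORM = set_layer_dim(SpatialDims.TWO)
conv = CONV(in_channels=1, out_channels=8, kernel_size=3, padding=1)

# A 64-pixel axis halved by 3 pooling ops: 64 -> 32 -> 16 -> 8.
assert dims_after_pooling(64, 3) == 8
```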
76 changes: 76 additions & 0 deletions avae/cyc_annealing.py
@@ -3,6 +3,82 @@
import numpy as np


def configure_annealing(
epochs: int,
value_max: float,
value_min: float,
cyc_method: str,
n_cycle: int,
ratio: float,
cycle_load: str | None = None,
):
"""
This function is used to configure the annealing of the beta and gamma values.
It creates an array of values that oscillate between a maximum and minimum value
for a defined number of cycles. This is used for gamma and beta in the loss term.
The function also allows for the loading of a pre-existing array of beta or gamma values.

Parameters
----------
epochs: int
Number of epochs in training
value_max: float
Maximum value of the beta or gamma
value_min: float
Minimum value of the beta or gamma
    cyc_method : str
        The method for constructing the cyclical mixing weight:
        - flat : regular beta-VAE (constant weight)
        - linear
        - sigmoid
        - cosine
        - ramp
        - delta
    n_cycle: int
        Number of times the variable oscillates between min and max
        over the epochs
    ratio: float
        Ratio of increase during ramping
    cycle_load: str | None
        Path to a file containing the beta or gamma array

Returns
-------
cycle_arr: np.ndarray
Array of beta or gamma values

"""
    if value_max == 0 and cyc_method != "flat" and cycle_load is None:
        raise RuntimeError(
            "The maximum value for beta is set to 0, so it is not possible "
            "to oscillate between a maximum and minimum. Please choose the "
            "flat method for cyc_method_beta."
        )

if cycle_load is None:
# If a path for loading the beta array is not provided,
# create it given the input
cycle_arr = (
cyc_annealing(
epochs,
cyc_method,
n_cycle=n_cycle,
ratio=ratio,
).var
* (value_max - value_min)
+ value_min
)
else:
cycle_arr = np.load(cycle_load)
        if len(cycle_arr) != epochs:
            raise RuntimeError(
                f"The length of the beta array loaded from file is "
                f"{len(cycle_arr)} but the number of epochs specified in "
                f"the input is {epochs}.\n"
                "These two values should be the same."
            )

return cycle_arr


class cyc_annealing:

"""
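
As a quick illustration of `configure_annealing`, the sketch below builds a
four-cycle cosine schedule for beta. The argument values are hypothetical, and
it assumes the `cyc_annealing` class produces one weight per epoch:

```python
from avae.cyc_annealing import configure_annealing

beta_arr = configure_annealing(
    epochs=100,        # length of the returned schedule
    value_max=1.0,     # beta peaks at 1.0 ...
    value_min=0.0,     # ... and bottoms out at 0.0
    cyc_method="cosine",
    n_cycle=4,         # four min-to-max oscillations
    ratio=0.5,         # fraction of each cycle spent ramping
)
assert len(beta_arr) == 100  # one weight per epoch
```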
28 changes: 14 additions & 14 deletions avae/decoders/decoders.py
Expand Up @@ -4,16 +4,16 @@
import numpy as np
import torch

from avae.base import dims_after_pooling, set_layer_dim
from avae.decoders.base import AbstractDecoder
from avae.models import dims_after_pooling, set_layer_dim


class Decoder(AbstractDecoder):
"""Affinity decoder. Includes optional pose component merge.

Parameters
----------
input_size: tuple (X, Y) or tuple (X, Y, Z)
input_shape: tuple (X, Y) or tuple (X, Y, Z)
Tuple representing the size of the data for each image
dimension X, Y and Z.
latent_dims: int
@@ -33,7 +33,7 @@ class Decoder(AbstractDecoder):

def __init__(
self,
input_size: tuple,
input_shape: tuple,
capacity: int | None = None,
filters: list[int] | None = None,
depth: int = 4,
@@ -76,22 +76,22 @@ def __init__(
assert all(
[
int(x) == x
for x in np.array(input_size) / (2 ** len(self.filters))
for x in np.array(input_shape) / (2 ** len(self.filters))
]
), (
"Input size not compatible with --depth. Input must be divisible "
"by {}.".format(2 ** len(self.filters))
)

self.bottom_dim = tuple(
[int(i / (2 ** len(self.filters))) for i in input_size]
[int(i / (2 ** len(self.filters))) for i in input_shape]
)

# define layer dimensions
CONV, TCONV, BNORM = set_layer_dim(len(input_size))
CONV, TCONV, BNORM = set_layer_dim(len(input_shape))

else:
self.bottom_dim = input_size
self.bottom_dim = input_shape

if latent_dims <= 0:
raise RuntimeError(
@@ -177,7 +177,7 @@ def forward(self, x, x_pose):
class DecoderA(AbstractDecoder):
def __init__(
self,
input_size: tuple,
input_shape: tuple,
capacity: int = 8,
depth: int = 4,
latent_dims: int = 8,
@@ -189,7 +189,7 @@ def __init__(
self.bnorm = bnorm

assert all(
[int(x) == x for x in np.array(input_size) / (2**depth)]
[int(x) == x for x in np.array(input_shape) / (2**depth)]
), (
"Input size not compatible with --depth. Input must be divisible "
"by {}.".format(2**depth)
@@ -200,7 +200,7 @@ def __init__(
[
filters[-1],
]
+ [dims_after_pooling(ax, depth) for ax in input_size]
+ [dims_after_pooling(ax, depth) for ax in input_shape]
)
flat_shape = np.prod(unflat_shape)

@@ -276,7 +276,7 @@ class DecoderB(AbstractDecoder):

def __init__(
self,
input_size: tuple,
input_shape: tuple,
capacity: int,
depth: int = 4,
latent_dims: int = 8,
@@ -288,13 +288,13 @@ def __init__(
self.pose = not (pose_dims == 0)

assert all(
[int(x) == x for x in np.array(input_size) / (2**depth)]
[int(x) == x for x in np.array(input_shape) / (2**depth)]
), (
"Input size not compatible with --depth. Input must be divisible "
"by {}.".format(2**depth)
)
_, TCONV, BNORM = set_layer_dim(len(input_size))
self.bottom_dim = tuple([int(i / (2**depth)) for i in input_size])
_, TCONV, BNORM = set_layer_dim(len(input_shape))
self.bottom_dim = tuple([int(i / (2**depth)) for i in input_shape])

# iteratively define deconvolution and batch normalisation layers
self.conv_dec = torch.nn.ModuleList()
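
To make the `input_size` → `input_shape` rename concrete, here is a minimal,
hedged sketch of constructing the affinity `Decoder` from this PR. The
argument values are illustrative, and it assumes `capacity` and `depth`
together define the filter list, as the divisibility assert above implies:

```python
from avae.decoders.decoders import Decoder

# Each spatial axis must be divisible by 2**depth: 64 / 2**4 = 4, so
# (64, 64) is a valid 2D input shape (a 3-tuple would select 3D layers).
decoder = Decoder(
    input_shape=(64, 64),
    capacity=8,
    depth=4,
    latent_dims=10,
)
```

Per `forward(self, x, x_pose)`, the decoder is then called with a latent
vector and, when pose is enabled, a pose vector.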