Improve network definition.
Applied improvements to network layers to simplify the execution graph and increase robustness.
Expanded tests and included a notebook to visualize SW-MSA.
MidnessX committed Jul 19, 2023
1 parent ffb921e commit bfc68bb
Showing 9 changed files with 1,153 additions and 705 deletions.
21 changes: 20 additions & 1 deletion .gitignore
@@ -1,3 +1,22 @@

# Created by https://www.toptal.com/developers/gitignore/api/linux,visualstudiocode,python
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,visualstudiocode,python

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
@@ -185,4 +204,4 @@ cython_debug/
# Ignore code-workspaces
*.code-workspace

# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
# End of https://www.toptal.com/developers/gitignore/api/linux,visualstudiocode,python
11 changes: 11 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,11 @@
{
"python.testing.unittestArgs": [
"-v",
"-s",
"./tests",
"-p",
"test_*.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
}
162 changes: 69 additions & 93 deletions README.md
@@ -4,151 +4,127 @@

![Swin Transformer architecture](https://github.com/microsoft/Swin-Transformer/blob/3b0685bf2b99b4cf5770e47260c0f0118e6ff1bb/figures/teaser.png)

This is a TensorFlow 2.0 implementation of the [Swin Transformer architecture](https://arxiv.org/abs/2103.14030).
This is a Keras/TensorFlow 2.0 implementation of the [Swin Transformer architecture](https://arxiv.org/abs/2103.14030) inspired by the official PyTorch [code](https://github.com/microsoft/Swin-Transformer).

It is built using the Keras API following best practices, such as allowing complete serialization and deserialization of custom layers and deferring weight creation until the first call with real inputs.

This implementation is inspired by the [official version](https://github.com/microsoft/Swin-Transformer) offered by the authors of the paper, while improving on it in some areas, such as shape and type checks.

## Installation

Clone the repository:
```bash
git clone git@github.com:MidnessX/swin.git
```
Enter into the directory:
Enter into it:
```bash
cd swin
```
Install the package via:
```bash
pip install -e .
pip install swin-transformer
```

## Usage

Class ``Swin`` in ``swin.model`` is a subclass of ``tf.keras.Model``, so you can instantiate Swin Transformers and train them through well-known interface methods, such as ``compile()``, ``fit()``, and ``save()``.

The only caveat is the first argument to the ``Swin`` class constructor, which is expected to be a ``tf.Tensor`` object or equivalent, such as a symbolic tensor produced by ``tf.keras.Input``.
This tensor is only used to determine the shape of future inputs and can be an example coming from your dataset or any random tensor sharing its shape.

For convenience, ``swin.model`` also includes classes for variants of the Swin architecture described in the article (``SwinT``, ``SwinS``, ``SwinB``, ``SwinL``) which initialize a ``Swin`` object with the variant's parameters.

## Example

```python
import tensorflow as tf

from swin.model import SwinT

# Load the dataset as a list of mini batches
train_x = ...
train_y = ...
num_classes = ...
import tensorflow.keras as keras
from swin import Swin

# Take a mini batch from the dataset to build the model
mini_batch = train_x[0]
# Dataset loading, omitted for brevity
x = [...]
y = [...]
num_classes = [...]

model = SwinT(mini_batch, num_classes)
model = Swin(num_classes)

# Build the model by calling it for the first time
model(mini_batch)

# Compile the model
model.compile(
loss=tf.keras.losses.SGD(learning_rate=1e-3, momentum=0.9),
optimizer=tf.keras.optimizers.CategoricalCrossentropy(),
metrics=[tf.keras.metrics.CategoricalAccuracy()]
optimizer=keras.optimizers.AdamW(),
loss=keras.losses.CategoricalCrossentropy(),
metrics=[keras.metrics.CategoricalAccuracy()]
)

# Train the model
history = model.fit(train_x, train_y, epochs=300)
model.fit(
x,
y,
epochs=1000,
)

# Save the trained model
model.save("path/to/model/directory")
```
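
Since the custom layers support complete serialization (as noted above), a saved model can later be reloaded. A minimal sketch, assuming all of the package's custom objects are registered for deserialization:

```python
import tensorflow.keras as keras

# Reload the model saved above. This assumes the package's custom layers
# are registered for deserialization, as the serialization support
# described earlier suggests.
model = keras.models.load_model("path/to/model/directory")
```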

## Notes

- The input type accepted by the model is ``tf.float32``. Any pre-processing of data should therefore convert images from ``tf.uint8`` to ``tf.float32`` if necessary (see the sketch after this list).

- Swin architectures have many parameters, so training them is not an easy task. Expect a lot of trial and error before honing in on the right hyperparameters.

- ``SwinModule`` layers place the dimensionality reduction layer (``SwinPatchMerging``) after transformer layers (``SwinTransformer``), rather than before as found in the paper. This choice is to maintain consistency with the original network implementation.
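
To illustrate the first note, a minimal pre-processing sketch (``images`` is a dummy stand-in for a real batch):

```python
import tensorflow as tf

# Dummy batch of uint8 images, standing in for a real dataset.
images = tf.zeros((8, 224, 224, 3), dtype=tf.uint8)

# The model expects tf.float32 inputs, so cast before feeding it.
images = tf.cast(images, tf.float32)
```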

## Testing

Test modules can be found under the ``tests`` folder of this repository.
They can be executed to test the expected functionality of custom layers for the Swin architecture, as well as basic functionalities of the whole model.

Admittedly, these tests could be expanded to cover more cases, but they should be enough to verify general functionality.

## Assumptions and simplifications

While implementing the Swin Transformer architecture, a number of assumptions and simplifications were made:

1. Input images must have 3 channels.
This network has been built to be consistent with its [official PyTorch implementation](https://github.com/microsoft/Swin-Transformer).
This translates into the following statements:

2. The size of windows in (Shifted) Windows Multi-head Attention is fixed to 7[^1].
- The ratio of hidden to output neurons in MLP blocks is set to 4.
- Projection of input data to obtain `Q`, `K`, and `V` includes a bias term in all transformer blocks.
- The product of `Q` and `K` is scaled by `1/sqrt(d)`, where `d` is the dimension of each attention head.
- No _Dropout_ is applied to attention heads.
- [_Stochastic Depth_](https://arxiv.org/pdf/1603.09382.pdf) is applied to randomly skip residual branches after the attention computation, with probability set to 10%.
- No absolute position information is added to embeddings.
- _Layer Normalization_ is applied to embeddings.
- The extraction of patches from images and the generation of embeddings both happen in the `SwinPatchEmbeddings` layer.
- Patch merging happens at the end of each stage, rather than at the beginning.
This simplifies the definition of layers and does not change the overall architecture.

3. The ratio of hidden to output neurons in ``SwinMlp`` layers is fixed to 4[^1].
Additionally, the following decisions have been made to simplify development:

4. A learnable bias is added to ``queries``, ``keys`` and ``values`` when computing (Shifted) Window Multi-head Attention[^2].

5. ``queries`` and ``keys`` are scaled by a factor of ``head_dimension**-0.5``[^1].

6. No dropout is applied to attention heads[^2].

7. The probability of the Stochastic Depth computation-skipping technique during training is fixed to 0.1[^2].

8. No absolute position information is included in embeddings[^3].

9. ``LayerNormalization`` is applied after building patch embeddings[^2].

[^1]: To stay consistent with the content of the paper.

[^2]: In the original implementation this happens when using default arguments.

[^3]: Researchers note in the paper that adding absolute position information to embeddings decreases network performance.
- The network only accepts square `tf.float32` images with 3 channels as inputs (i.e. height and width must be identical).
- No padding is applied to embeddings during the SW-MSA computation, as their spatial size is assumed to be a multiple of the window size (see the sketch below).
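
To illustrate the last point, here is a minimal sketch of window partitioning under the no-padding assumption; it illustrates the idea and is not the package's actual layer:

```python
import tensorflow as tf


def window_partition(x: tf.Tensor, window_size: int) -> tf.Tensor:
    """Cut a (batch, height, width, channels) patch grid into
    non-overlapping windows of shape (window_size, window_size)."""
    b, h, w, c = x.shape
    # Reshaping alone suffices because h and w are exact multiples
    # of window_size, so no padding is needed.
    x = tf.reshape(
        x, (b, h // window_size, window_size, w // window_size, window_size, c)
    )
    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
    return tf.reshape(x, (-1, window_size, window_size, c))


# A 56x56 patch grid cut into 7x7 windows yields 64 windows per image.
windows = window_partition(tf.zeros((1, 56, 56, 96)), 7)
print(windows.shape)  # (64, 7, 7, 96)
```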

## Choosing parameters

### Dependencies

If using the base class (``Swin``), it is necessary to provide a series of parameters to instantiate the model.
The choice of these values is important, and a series of dependencies exists between them.
When using any of the subclasses (``SwinT``, ``SwinS``, ``SwinB``, ``SwinL``), the architecture is fixed to their respective variants found in the paper.

The size of windows (``window_size``) used during (Shifted) Windows Multi-head Self Attention is the starting point and, as stated in the section about [assumptions](https://github.com/MidnessX/swin#assumptions-and-simplifications), it is fixed to ``7`` (as in the original paper).
When using the `Swin` class directly, however, you can customize the resulting architecture by specifying all the network's parameters.
This section provides an overview of the dependencies existing between these parameters.

The resolution of inputs to network stages, expressed as the number of patches along each axis, must be a multiple of ``window_size`` and gets halved by every stage through ``SwinPatchMerging`` layers.
The suggestion is to choose a resolution for the final stage and multiply it by ``2`` for every stage in the desired model, obtaining the input resolution of the first stage (``resolution_stage_1``).
- Each stage has an input with shape `(batch_size, num_patches, num_patches, embed_dim)`.
`num_patches` must be a multiple of `window_size`.
- Each stage halves the `num_patches` dimension by merging four adjacent patches together.
It can be easier to choose a desired number of patches in the last stage and multiply it by 2 for every stage in the network to obtain the initial `num_patches` value.
- By multiplying `num_patches` by `patch_size` you can find out the size in pixels of input images.
- `embed_dim` must be a multiple of `num_heads` for every stage.
- The number of transformer blocks in each stage can be set freely, as they do not alter the shape of patches.

Input images to the ``Swin`` model must be square, with their height/width given by multiplying ``resolution_stage_1`` by the desired size of patches (``patch_size``).
To better understand how to choose network parameters, consider the following example:

The number of ``SwinTransformer`` layers in each stage (``depths``) is arbitrary.
1. The depth is set to 3 stages.
2. Windows are set to be 8 patches wide (i.e. `window_size = 8`).
3. The last stage should have a `2 * window_size = 16` patch-wide input.
This means that the inputs to the second and first stages will be 32x32 and 64x64 patches respectively.
4. We require each patch to cover a 6x6 pixel area, so the initial image will be `num_patches * 6 = 64 * 6 = 384` pixels wide.
5. For the first stage, we choose 2 attention layers; 4 for the second, and 2 for the third.
6. The number of attention heads in the first stage is set to 4.
This implies that there will be 8 attention heads in the second stage and 16 in the third.
7. Using the value found in the Swin paper, the `embed_dim / num_heads` ratio is set to 32, leading to an initial `embed_dim` of `32 * 4 = 128`.

The number of transformer heads (``num_heads``) should instead double at each stage.
Authors of the paper use a fixed ratio between embedding dimensions and the number of heads in each stage of the network, amounting to ``32``.
This means that, once the number of transformer heads in the first stage has been chosen, it should be multiplied by ``32`` to obtain ``embed_dim``.
Summarizing, this gives the following values (a quick arithmetic check follows the list):

The following example should help clarify these concepts.
- `image_size = 384`
- `patch_size = 6`
- `window_size = 8`
- `embed_dim = 128`
- `depths = [2, 4, 2]`
- `num_heads = [4, 8, 16]`
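
The dependencies above can be verified with a quick arithmetic check; this is a standalone sketch based on the constraints described in this section, not part of the package:

```python
def check_params(image_size, patch_size, window_size, embed_dim, depths, num_heads):
    """Verify the parameter dependencies described above."""
    assert len(depths) == len(num_heads), "one entry per stage"
    assert image_size % patch_size == 0, "image must split evenly into patches"
    num_patches = image_size // patch_size
    for stage, heads in enumerate(num_heads):
        # The patch grid must tile evenly into windows at every stage.
        assert num_patches % window_size == 0
        # embed_dim doubles at each merge and must split evenly across heads.
        assert (embed_dim * 2**stage) % heads == 0
        num_patches //= 2  # patch merging halves the grid after each stage


check_params(
    image_size=384,
    patch_size=6,
    window_size=8,
    embed_dim=128,
    depths=[2, 4, 2],
    num_heads=[4, 8, 16],
)
```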

### Parameter choice example

Let's imagine we want a Swin Transformer having ``3`` stages.
The last stage (``stage_3``) should receive inputs of ``14x14`` patches (``14 = window_size * 2``); this also means that ``stage_2`` receives inputs of ``28x28`` patches and ``stage_1`` of ``56x56``.
## Testing

We want to convert our images into patches having size ``6x6``, so images should have size ``56 * 6 = 336``.
Test modules can be found under the ``tests`` folder of this repository.
They can be executed to verify the expected functionality of custom layers for the Swin architecture, as well as basic functionalities of the whole model.

Our network will have ``2`` transformers in the first stage, ``4`` in the second and ``2`` in the third.
We choose ``4`` heads for the first stage and thus the second one will have ``8`` heads while the third ``16``.
You can run them with the following command:
```bash
python -m unittest discover -s ./tests
```

With these numbers we can derive the size of embeddings used in the first stage by multiplying ``32`` by ``4``, giving us ``128``.
## Extras

Summarizing, we have:
To better understand how SW-MSA works, a Jupyter notebook in the `extras` folder can be used to visualize window partitioning, translation, and mask construction; the cyclic shift at its core is sketched below.
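
The cyclic shift can be sketched in a few lines (an illustration of the idea, not the notebook's exact code):

```python
import tensorflow as tf

window_size = 8
# Dummy grid of patch embeddings: (batch, height, width, channels).
x = tf.zeros((1, 32, 32, 128))

# SW-MSA rolls the grid by half a window along both spatial axes, so the
# windows cut afterwards straddle the original window boundaries; the
# attention mask then hides interactions between unrelated regions.
shift = window_size // 2
shifted = tf.roll(x, shift=(-shift, -shift), axis=(1, 2))
```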

- ``image_size = 336``
- ``patch_size = 6``
- ``embed_dim = 128``
- ``depths = [2, 4, 2]``
- ``num_heads = [4, 8, 16]``
404 changes: 404 additions & 0 deletions extras/cyclic_shift.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,3 +1,3 @@
[build-system]
requires = ["setuptools", "wheel"]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"