Applied improvements to network layers to simplify the execution graph and increase robustness. Expanded tests and included a notebook to visualize SW-MSA.
Showing 9 changed files with 1,153 additions and 705 deletions.

@@ -0,0 +1,11 @@
```json
{
  "python.testing.unittestArgs": [
    "-v",
    "-s",
    "./tests",
    "-p",
    "test_*.py"
  ],
  "python.testing.pytestEnabled": false,
  "python.testing.unittestEnabled": true
}
```

@@ -4,151 +4,127 @@

![Swin Transformer architecture](https://github.com/microsoft/Swin-Transformer/blob/3b0685bf2b99b4cf5770e47260c0f0118e6ff1bb/figures/teaser.png)

This is a Keras/TensorFlow 2.0 implementation of the [Swin Transformer architecture](https://arxiv.org/abs/2103.14030), inspired by the official PyTorch [code](https://github.com/microsoft/Swin-Transformer).

It is built using the Keras API following best practices, such as allowing complete serialization and deserialization of custom layers and deferring weight creation until the first call with real inputs.

Compared to the official version, it also improves in some areas, such as shape and type checks.

## Installation

Clone the repository:
```bash
git clone git@github.com:MidnessX/swin.git
```
Enter the directory:
```bash
cd swin
```
Install the package via:
```bash
pip install swin-transformer
```

## Usage

The ``Swin`` class in ``swin.model`` is a subclass of ``tf.keras.Model``, so you can instantiate Swin Transformers and train them through well-known interface methods, such as ``compile()``, ``fit()``, and ``save()``.

Weight creation is deferred until the model is first called on real inputs, so build it by calling it on a batch of data before use.

For convenience, ``swin.model`` also includes classes for variants of the Swin architecture described in the paper (``SwinT``, ``SwinS``, ``SwinB``, ``SwinL``), which initialize a ``Swin`` object with the variant's parameters.

## Example

```python
import tensorflow.keras as keras

from swin import Swin

# Dataset loading, omitted for brevity
x = ...
y = ...
num_classes = ...

model = Swin(num_classes)

# Build the model by calling it for the first time
model(x)

# Compile the model
model.compile(
    optimizer=keras.optimizers.AdamW(),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()]
)

# Train the model
model.fit(
    x,
    y,
    epochs=1000,
)

# Save the trained model
model.save("path/to/model/directory")
```
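
Since the custom layers support complete serialization, a model saved this way can be restored through the standard Keras API. A minimal sketch (the `images` placeholder is hypothetical, and depending on your Keras version custom layers may need to be registered for deserialization):

```python
import tensorflow.keras as keras

# Hypothetical batch of pre-processed tf.float32 images.
images = ...

# Restore the previously saved model, custom layers included.
model = keras.models.load_model("path/to/model/directory")

# Run inference.
predictions = model(images)
```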

## Notes

- The input type accepted by the model is ``tf.float32``. Any pre-processing of data should include a conversion step of images from ``tf.uint8`` to ``tf.float32`` if necessary (see the sketch below).

- Swin architectures have many parameters, so training them is not an easy task. Expect a lot of trial and error before homing in on the right hyperparameters.
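
A rough sketch of such a conversion step, assuming a hypothetical `tf.data.Dataset` of `(image, label)` pairs (names here are illustrative, not part of the package):

```python
import tensorflow as tf

# Hypothetical dataset yielding (uint8 image, label) pairs.
dataset = ...

def to_float32(image, label):
    # Cast uint8 pixel values to the tf.float32 dtype expected by the model.
    return tf.cast(image, tf.float32), label

dataset = dataset.map(to_float32)
```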

## Assumptions and simplifications

This network has been built to be consistent with its [official PyTorch implementation](https://github.com/microsoft/Swin-Transformer).
This translates into the following statements:

- The ratio of hidden to output neurons in MLP blocks is set to 4.
- The projection of input data to obtain `Q`, `K`, and `V` includes a bias term in all transformer blocks.
- `Q` and `K` are scaled by a factor of `1 / sqrt(d)`, where `d` is the size of each attention head.
- No _Dropout_ is applied to attention heads.
- [_Stochastic Depth_](https://arxiv.org/pdf/1603.09382.pdf) is applied to randomly disable patches after the attention computation, with probability set to 10%.
- No absolute position information is added to embeddings.
- _Layer Normalization_ is applied to embeddings.
- The extraction of patches from images and the generation of embeddings both happen in the `SwinPatchEmbeddings` layer.
- Patch merging happens at the end of each stage, rather than at the beginning.
  This simplifies the definition of layers and does not change the overall architecture.

Additionally, the following decisions have been made to simplify development:

- The network only accepts square `tf.float32` images with 3 channels as inputs (i.e. height and width must be identical).
- No padding is applied to embeddings during the SW-MSA calculation, as their size is assumed to be a multiple of the window size.
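
As a toy illustration of the projection-with-bias and `1 / sqrt(d)` scaling listed above (a single head with arbitrary shapes, not the repository's actual implementation):

```python
import tensorflow as tf

d = 32  # size of the attention head
tokens = tf.random.normal((49, 96))  # a 7x7 window of patch embeddings

# Dense projections with bias terms, as in the original implementation.
q = tf.keras.layers.Dense(d, use_bias=True)(tokens)
k = tf.keras.layers.Dense(d, use_bias=True)(tokens)
v = tf.keras.layers.Dense(d, use_bias=True)(tokens)

# Scale attention scores by 1 / sqrt(d) before the softmax.
scores = tf.nn.softmax(q @ tf.transpose(k) * d ** -0.5)
output = scores @ v
```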

## Choosing parameters

When using any of the subclasses (``SwinT``, ``SwinS``, ``SwinB``, ``SwinL``), the architecture is fixed to the respective variant found in the paper.
When using the ``Swin`` class directly, however, you can customize the resulting architecture by specifying all of the network's parameters.
This section provides an overview of the dependencies existing between these parameters (a sketch encoding them as assertions follows the list):

- Each stage has an input with shape `(batch_size, num_patches, num_patches, embed_dim)`.
  `num_patches` must be a multiple of `window_size`.
- Each stage halves the `num_patches` dimension by merging four adjacent patches together.
  It can be easier to choose a desired number of patches in the last stage and multiply it by 2 for every stage in the network to obtain the initial `num_patches` value.
- By multiplying `num_patches` by `patch_size` you can find out the size in pixels of input images.
- `embed_dim` must be a multiple of `num_heads` for every stage.
- The number of transformer blocks in each stage can be set freely, as they do not alter the shape of patches.
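
A minimal sketch of these checks (the function and parameter names are illustrative, not the package's API):

```python
def check_parameters(image_size, patch_size, window_size, embed_dim, depths, num_heads):
    # The image is split into image_size / patch_size patches per axis.
    assert image_size % patch_size == 0
    num_patches = image_size // patch_size

    # One entry of depths and num_heads per stage.
    assert len(depths) == len(num_heads)

    for heads in num_heads:
        # Window attention requires num_patches to be a multiple of window_size.
        assert num_patches % window_size == 0
        # Embeddings are split evenly across attention heads.
        assert embed_dim % heads == 0
        # Patch merging halves num_patches and doubles embed_dim between stages.
        num_patches //= 2
        embed_dim *= 2

# The parameter set from the worked example below passes these checks.
check_parameters(384, 6, 8, 128, [2, 4, 2], [4, 8, 16])
```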

To better understand how to choose network parameters, consider the following example:

1. The depth is set to 3 stages.
2. Windows are set to be 8 patches wide (i.e. `window_size = 8`).
3. The last stage should have a `2 * window_size = 16` patch-wide input.
   This means that the inputs to the second and first stages will be 32x32 and 64x64 patches respectively.
4. We require each patch to cover a 6x6 pixel area, so the initial image will be `num_patches * patch_size = 64 * 6 = 384` pixels wide.
5. For the first stage, we choose 2 attention layers; 4 for the second, and 2 for the third.
6. The number of attention heads in the first stage is set to 4; since the size of each head stays constant while `embed_dim` doubles at every stage, there will be 8 attention heads in the second stage and 16 in the third.
7. Using the value found in the Swin paper, the `embed_dim / num_heads` ratio is set to 32, leading to an initial `embed_dim` of `32 * 4 = 128`.

Summarizing, this is equal to:

- `image_size = 384`
- `patch_size = 6`
- `window_size = 8`
- `embed_dim = 128`
- `depths = [2, 4, 2]`
- `num_heads = [4, 8, 16]`
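
Assuming the ``Swin`` constructor accepts these values as keyword arguments (the exact signature is not shown in this diff, so treat this as a hypothetical sketch rather than the package's actual API):

```python
from swin import Swin

model = Swin(
    num_classes=1000,  # hypothetical number of target classes
    patch_size=6,
    window_size=8,
    embed_dim=128,
    depths=[2, 4, 2],
    num_heads=[4, 8, 16],
)
```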

## Testing

Test modules can be found under the ``tests`` folder of this repository.
They can be executed to verify the expected functionality of custom layers for the Swin architecture, as well as basic functionalities of the whole model.

You can run them with the following command:
```bash
python -m unittest discover -s ./tests
```

## Extras

To better understand how SW-MSA works, the Jupyter notebook found in the `extras` folder can be used to visualize window partitioning, translation, and mask construction.
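
For a quick taste of what the notebook visualizes, the cyclic shift at the heart of SW-MSA can be reproduced in a few lines. This is a standalone illustration of the technique from the paper, not code taken from this repository:

```python
import tensorflow as tf

window_size = 4
num_patches = 8  # must be a multiple of window_size

# Label each patch so that window membership is easy to follow after shifting.
patches = tf.reshape(tf.range(num_patches * num_patches), (num_patches, num_patches))

# Cyclically shift patches up and left by half a window, as done before SW-MSA.
shifted = tf.roll(patches, shift=(-window_size // 2, -window_size // 2), axis=(0, 1))

# Partition the shifted grid into non-overlapping window_size x window_size windows.
windows = tf.reshape(
    shifted,
    (num_patches // window_size, window_size, num_patches // window_size, window_size),
)
windows = tf.transpose(windows, (0, 2, 1, 3))

print(windows[0, 0])  # the top-left window after the cyclic shift
```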

@@ -1,3 +1,3 @@
```toml
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
```