diff --git a/.gitignore b/.gitignore index b26562d..2dafb2a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,22 @@ + +# Created by https://www.toptal.com/developers/gitignore/api/linux,visualstudiocode,python +# Edit at https://www.toptal.com/developers/gitignore?templates=linux,visualstudiocode,python + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ @@ -185,4 +204,4 @@ cython_debug/ # Ignore code-workspaces *.code-workspace -# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode \ No newline at end of file +# End of https://www.toptal.com/developers/gitignore/api/linux,visualstudiocode,python \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..e9e6a80 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true +} \ No newline at end of file diff --git a/README.md b/README.md index 0445dec..34935de 100644 --- a/README.md +++ b/README.md @@ -4,151 +4,127 @@ ![Swin Transformer architecture](https://github.com/microsoft/Swin-Transformer/blob/3b0685bf2b99b4cf5770e47260c0f0118e6ff1bb/figures/teaser.png) -This is a TensorFlow 2.0 implementation of the [Swin Transformer architecture](https://arxiv.org/abs/2103.14030). +This is a Keras/TensorFlow 2.0 implementation of the [Swin Transformer architecture](https://arxiv.org/abs/2103.14030) inspired by the official Pytorch [code](https://github.com/microsoft/Swin-Transformer). 
It is built using the Keras API following best practices, such as allowing complete serialization and deserialization of custom layers and deferring weight creation until the first call with real inputs. -This implementation is inspired by the [official version](https://github.com/microsoft/Swin-Transformer) offered by authors of the paper, while simultaneously improving in some areas such as shape and type checks. - ## Installation Clone the repository: ```bash git clone git@github.com:MidnessX/swin.git ``` -Enter into the directory: +Enter into it: ```bash cd swin ``` Install the package via: ```bash -pip install -e . +pip install swin-transformer ``` ## Usage Class ``Swin`` in ``swin.model`` is a subclass of ``tf.keras.Model``, so you can instantiate Swin Transformers and train them through well known interface methods, such as ``compile()``, ``fit()``, ``save()``. -The only remark is the first argument to the ``Swin`` class constructor, which is expected to be a ``tf.Tensor`` object or equivalent, such as a symbolic tensor produced by ``tf.keras.Input``. -This tensor is only used to determine the shape of future inputs and can be an example coming from your dataset or any random tensor sharing its shape. - For convenience, ``swin.model`` also includes classes for variants of the Swin architecture described in the article (``SwinT``, ``SwinS``, ``SwinB``, ``SwinL``) which initialize a ``Swin`` object with the variant's parameters. ## Example ```python -import tensorflow as tf - -from swin.model import SwinT - -# Load the dataset as a list of mini batches -train_x = ... -train_y = ... -num_classes = ... +import tensorflow.keras as keras +from swin import Swin -# Take a mini batch from the dataset to build the model -mini_batch = train_x[0] +# Dataset loading, omitted for brevity +x = [...] +y = [...] +num_classes = [...] 
-model = SwinT(mini_batch, num_classes) +model = Swin(num_classes) -# Build the model by calling it for the first time -model(mini_batch) - -# Compile the model model.compile( - loss=tf.keras.losses.SGD(learning_rate=1e-3, momentum=0.9), - optimizer=tf.keras.optimizers.CategoricalCrossentropy(), - metrics=[tf.keras.metrics.CategoricalAccuracy()] + optimizer=keras.optimizers.AdamW(), + loss=keras.losses.CategoricalCrossentropy(), + metrics=[keras.metrics.CategoricalAccuracy()] ) -# Train the model -history = model.fit(train_x, train_y, epochs=300) +model.fit( + x, + y, + epochs=1000, +) -# Save the trained model model.save("path/to/model/directory") ``` ## Notes -- The input type accepted by the model is ``tf.float32``. Any pre-processing of data should include a conversion step of images from ``tf.uint8`` to ``tf.float32`` if necessary. - -- Swin architectures have many parameters, so training them is not an easy task. Expect a lot of trial & error before honing in correct hyperparameters. - -- ``SwinModule`` layers place the dimensionality reduction layer (``SwinPatchMerging``) after transformer layers (``SwinTransformer``), rather than before as found in the paper. This choice is to maintain consistency with the original network implementation. - -## Testing - -Test modules can be found under the ``tests`` folder of this repository. -They can be executed to test the expected functionality of custom layers for the Swin architecture, as well as basic functionalities of the whole model. - -Admittedly these tests could be expanded and further improved to cover more cases, but they should be enough to verify general functionality. - -## Assumptions and simplifications - -While implementing the Swin Transformer architecture a number of assumptions and simplifications have been made: - -1. Input images must have 3 channels. +This network has been built to be consistent with its [official Pytorch implementation](https://github.com/microsoft/Swin-Transformer). 
+This translates into the following statements: -2. The size of windows in (Shifted) Windows Multi-head Attention is fixed to 7[^1]. +- The ratio of hidden to output neurons in MLP blocks is set to 4. +- Projection of input data to obtain `Q`, `K`, and `V` includes a bias term in all transformer blocks. +- `Q` and `K` are scaled by `1/sqrt(d)`, where `d` is the size of `Q` and `K`. +- No _Dropout_ is applied to attention heads. +- [_Stochastic Depth_](https://arxiv.org/pdf/1603.09382.pdf) is applied to randomly disable patches after the attention computation, with probability set to 10%. +- No absolute position information is added to embeddings. +- _Layer Normalization_ is applied to embeddings. +- The extraction of patches from images and the generation of embeddings both happen in the `SwinPatchEmbeddings` layer. +- Patch merging happens at the end of each stage, rather than at the beginning. + This simplifies the definition of layers and does not change the overall architecture. -3. The ratio of hidden to output neurons in ``SwinMlp`` layers is fixed to 4[^1]. +Additionally, the following decisions have been made to simplify development: -4. A learnable bias is added to ``queries``, ``keys`` and ``values`` when computing (Shifted) Window Multi-head Attention[^2]. - -5. ``queries`` and ``keys`` are scaled by a factor of ``head_dimension**-0.5``[^1]. - -6. No dropout is applied to attention heads[^2]. - -7. The probability of the Stochastic Depth computation-skipping technique during training is fixed to 0.1[^2]. - -8. No absolute position information is included in embeddings[^3]. - -9. ``LayerNormalization`` is applied after building patch embeddings[^2]. - -[^1]: To stay consistent with the content of the paper. - -[^2]: In the original implementation this happens when using default arguments. - -[^3]: Researchers note in the paper that adding absolute position information to embedding decreases network capabilities. 
+- The network only accepts square `tf.float32` images with 3 channels as inputs (i.e. height and width must be identical). +- No padding is applied to embeddings during the SW-MSA calculation, as their size is assumed to be a multiple of window size. ## Choosing parameters -### Dependencies - -If using the base class (``Swin``), it is necessary to provide a series of parameters to instantiate the model. -The choice of these values is important and a series of dependencies exist between them. +When using any of the subclasses (``SwinT``, ``SwinS``, ``SwinB``, ``SwinL``), the architecture is fixed to their respective variants found in the paper. -The size of windows (``window_size``) used during (Shifted) Windows Multi-head Self Attention is the starting point and, as stated in the section about [assumptions](https://github.com/MidnessX/swin#assumptions-and-simplifications), it is fixed to ``7`` (as in the original paper). +When using the `Swin` class directly, however, you can customize the resulting architecture by specifying all the network's parameters. +This section provides an overview of the dependencies existing between these parameters. -The resolution of inputs to network stages, expressed as the number of patches along each axis, must be a multiple of ``window_size`` and gets halved by every stage through ``SwinPatchMerging`` layers. -The suggestion is to choose a resolution for the final stage and multiply it by ``2`` for every stage in the desired model, obtaining the input resolution of the first stage (``resolution_stage_1``). +- Each stage has an input with shape `(batch_size, num_patches, num_patches, embed_dim)`. +`num_patches` must be a multiple of `window_size`. +- Each stage halves the `num_patches` dimension by merging four adjacent patches together. + It can be easier to choose a desired number of patches in the last stage and multiply it by 2 for every stage in the network to obtain the initial `num_patches` value. 
+- By multiplying `num_patches` by `patch_size` you can find out the size in pixels of input images. +- `embed_dim` must be a multiple of `num_heads` for every stage. +- The number of transformer blocks in each stage can be set freely, as they do not alter the shape of patches. -Input images to the ``Swin`` model must be squares, with their height/width given by multiplying ``resolution_stage_1`` with the desired size of patches (``patch_size``). +To better understand how to choose network parameters, consider the following example: -The number of ``SwinTransformer`` layers in each stage (``depths``) is arbitrary. +1. The depth is set to 3 stages. +2. Windows are set to be 8 patches wide (i.e. `window_size = 8`). +3. The last stage should have a `2 * window_size = 16` patch-wide input. + This means that the input to the second stage and the first stage will be 32x32 and 64x64 patch-wide respectively. +4. We require each patch to cover a 6x6 pixel area, so the initial image will be `num_patches * 6 = 64 * 6 = 384` pixel wide. +5. For the first stage, we choose 2 attention layers; 4 for the second, and 2 for the third. +6. The number of attention heads is set to 4. + This implies that there will be 8 attention heads in the second stage and 16 attention heads in the third stage. +7. Using the value found in the Swin paper, the `embed_dim / num_heads` ratio is set to 32, leading to an initial `embed_dim` of `32 * 4 = 128`. -The number of transformer heads (``num_heads``) should instead double at each stage. -Authors of the paper use a fixed ratio between embedding dimensions and the number of heads in each stage of the network, amounting to ``32``. -This means that, chosen the number of transformer heads in the first stage, it should be multiplied by ``32`` to obtain ``embed_dim``. +Summarizing, this is equal to: -The following example should help clarify these concepts. 
+- `image_size = 384` +- `patch_size = 6` +- `window_size = 8` +- `embed_dim = 128` +- `depths = [2, 4, 2]` +- `num_heads = [4, 8, 16]` -### Parameter choice example - -Let's imagine we want a Swin Transformer having ``3`` stages. -The last stage (``stage_3``) should receive inputs of ``14x14`` patches (``14 = window_size * 2``); this also means that ``stage_2`` receives inputs of ``28x28`` patches and ``stage_1`` of ``56x56``. +## Testing -We want to convert our images into patches having size ``6x6``, so images should have size ``56 * 6 = 336``. +Test modules can be found under the ``tests`` folder of this repository. +They can be executed to verify the expected functionality of custom layers for the Swin architecture, as well as basic functionalities of the whole model. -Our network will have ``2`` transformers in the first stage, ``4`` in the second and ``2`` in the third. -We choose ``4`` heads for the first stage and thus the second one will have ``8`` heads while the third ``16``. +You can run them with the following command: +```bash +python -m unittest discover -s ./tests +``` -With these numbers we can derive the size of embeddings used in the first stage by multiplying ``32`` by ``4``, giving us ``128``. +## Extras -Summarizing, we have: +To better understand how SW-MSA works, a Jupyter notebook found in the `extras` folder can be used to visualize window partitioning, translation and mask construction. 
-- ``image_size = 336`` -- ``patch_size = 6`` -- ``embed_dim = 128`` -- ``depths = [2, 4, 2]`` -- ``num_heads = [4, 8, 16]`` \ No newline at end of file diff --git a/extras/cyclic_shift.ipynb b/extras/cyclic_shift.ipynb new file mode 100644 index 0000000..79ed558 --- /dev/null +++ b/extras/cyclic_shift.ipynb @@ -0,0 +1,404 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cyclic shifting and SW-MSA masking\n", + "\n", + "This notebook is meant to help understand what happens during cyclic shifting of patches in the computation of Shifted Windows Multi-head Self Attention." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-07-05 16:36:18.110086: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2023-07-05 16:36:18.136710: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.lines as lines\n", + "from math import ceil, floor\n", + "from swin.modules import SwinTransformer\n", + "import tensorflow as tf\n", + "\n", + "%matplotlib inline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define a bunch of constants which determine the size of the input feature map to the SW-MSA." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "patch_size = 2 # in px\n", + "map_size = 8 # in px\n", + "n_patches = ceil(map_size / patch_size)\n", + "\n", + "res = int(map_size / patch_size) # in # of patches\n", + "window_size = 2 # in # of patches\n", + "shift_size = 1 # in # of patches\n", + "\n", + "windows_res = int(map_size / (window_size * patch_size)) # in windows" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We first plot the feature map with windows overlaid on top.\n", + "Red lines denote the windows without any shifting applied, while cyan lines denote windows after shifting.\n", + "\n", + "We then plot the same feature map after having applied cyclic shift.\n", + "We can see that patches are simultaneously moved to the top and to the left, with those overflowing ending at the bottom and to the right.\n", + "Again, red lines denote the windows in which the feature map will be split into.\n", + "\n", + "We can clearly see that, after shifting, windows correspond to those painted in Cyan in the first image.\n", + "It's also worth mentioning that some windows are made of patches which were not adjacent in the original feature map.\n", + "This reason is why we need a mask during SW-MSA." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "patches_cm = mpl.colors.ListedColormap(plt.cm.Blues(np.linspace(0, 1, n_patches**2)), \"Patches colormap\")\n", + "\n", + "fm = np.zeros((map_size, map_size))\n", + "\n", + "color = 0\n", + "for x in range(n_patches):\n", + " for y in range(n_patches):\n", + " # print(f\"{x*patch_size}:{(x+1)*patch_size}, {y*patch_size}:{(y+1)*patch_size} -> {color}\")\n", + " fm[x*patch_size:(x+1)*patch_size,y*patch_size:(y+1)*patch_size] = color\n", + " color += 1\n", + "\n", + "shifted_fm = np.roll(\n", + " fm, shift=[-shift_size * patch_size, -shift_size * patch_size], axis=[0, 1]\n", + ")\n", + "\n", + "fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)\n", + "for i, data in enumerate([fm, shifted_fm]):\n", + " im = axes[i].matshow(data, cmap=patches_cm, extent=[0, map_size, map_size, 0])\n", + " for j in range(floor(map_size / (window_size * patch_size))):\n", + " axes[i].add_artist(\n", + " lines.Line2D(\n", + " [\n", + " window_size * patch_size * j,\n", + " window_size * patch_size * j\n", + " ],\n", + " [\n", + " 0,\n", + " map_size\n", + " ],\n", + " color=\"Red\"\n", + " )\n", + " )\n", + " axes[i].add_artist(\n", + " lines.Line2D(\n", + " [\n", + " 0,\n", + " map_size\n", + " ],\n", + " [\n", + " window_size * patch_size * j,\n", + " window_size * patch_size * j\n", + " ],\n", + " color=\"Red\"\n", + " )\n", + " )\n", + "\n", + " if i == 0: # Only draw shifted windows boundaries on the first image\n", + " axes[i].add_artist(\n", + " lines.Line2D(\n", + " [\n", + " window_size * patch_size * j + shift_size * patch_size,\n", + " window_size * patch_size * j + shift_size * patch_size\n", + " ],\n", + " [\n", + " 0 + shift_size * patch_size,\n", + " map_size\n", + " ],\n", + " color=\"Cyan\"\n", + " )\n", + " )\n", + " axes[i].add_artist(\n", + " lines.Line2D(\n", + " [\n", + " 0 + shift_size * patch_size,\n", + " map_size\n", + " ],\n", + " [\n", + " window_size * patch_size * j + 
shift_size * patch_size,\n", + " window_size * patch_size * j + shift_size * patch_size\n", + " ],\n", + " color=\"Cyan\"\n", + " )\n", + " )\n", + "fig.colorbar(im, ax=axes.ravel().tolist())" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now proceed to build and visualize the mask which will be used during SW-MSA.\n", + "This mask is to prevent patches to pay attention to other patches in the same window which are not really adjacent to each other due to the cyclic shift applied earlier.\n", + "\n", + "The first image depicts, step by step, how patches in each window get a value assigned denoting their origin.\n", + "For example, the top-left-most window only contains patches coming from the same window which gets the value 0 assigned.\n", + "On the other hand, the bottom-right-most windows is made of patches coming from a variety of windows, each getting a different value assigned." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mask = np.zeros([1, res, res, 1])\n", + "\n", + "h_slices = (\n", + " slice(0, -window_size),\n", + " slice(-window_size, -shift_size),\n", + " slice(-shift_size, None),\n", + ")\n", + "w_slices = (\n", + " slice(0, -window_size),\n", + " slice(-window_size, -shift_size),\n", + " slice(-shift_size, None),\n", + ")\n", + "\n", + "windows_cm = mpl.colors.ListedColormap(\n", + " plt.cm.Reds(np.linspace(0, 1, len(h_slices) * len(w_slices))), \n", + " \"Patches colormap\"\n", + ")\n", + "fig, axes = plt.subplots(\n", + " ceil(len(h_slices) * len(w_slices) / 3), \n", + " 3, \n", + " sharex=True, \n", + " sharey=True\n", + ")\n", + "\n", + "i = 0\n", + "for h_slice in h_slices:\n", + " for w_slice in w_slices:\n", + " mask[:, h_slice, w_slice, :] = i\n", + "\n", + " im = axes[floor(i / 3), i % 3].matshow(\n", + " mask[0, :, :, 0], \n", + " cmap=windows_cm, \n", + " extent=[0, res, res, 0], \n", + " vmin=0, \n", + " vmax=len(windows_cm.colors)\n", + " )\n", + "\n", + " i += 1\n", + "\n", + "fig.colorbar(im, ax=axes.ravel().tolist())" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This mask cannot be used as is, so we need to do some further processing.\n", + "\n", + "The first step is to split it into windows, as can be seen in the first figure.\n", + "\n", + "We must remember that the self-attention mechanism performs the calculation between each pair of patches in a window.\n", + "For pair of patches that were distant prior to cyclic shifting we want the attention value to be negative, so that it will result in a value close to 0 when put through the SoftMax operation.\n", + "This can be achieved by summing the attention matrix with a mask matrix having cells with large negative values for those pairs of cells distant from each other.\n", + "\n", + "Currently, however, (1) our mask does not have the right shape and the (2) right 
values.\n", + "To fix it, we need to do the following:\n", + "\n", + "1.\n", + " - We flatten each window by concatenating its patches along a single axis, as can be seen in the second figure.\n", + " - We broadcast this flattened vector into a square matrix by repeating its values for each row. We do the same with the transposed vector, turning it into a matrix by repeating its values for each column.\n", + " - We subtract the first matrix to the second: this yields a matrix with zeros where the two subtracted values were identical and other numbers where values were different. 0 denotes a pair of patches which are adjacent and for which the computation of attention is valid.\n", + "2. We transform every non-zero value into a big negative value (in this case, -100).\n", + "\n", + "Our final mask can be seen in the third figure.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mask_windows = SwinTransformer.window_partition(mask, window_size) # (n_windows * 1, window_size, window_size, 1)\n", + "fig, axes = plt.subplots(windows_res, windows_res)\n", + "for i in range(windows_res):\n", + " for j in range(windows_res):\n", + " im = axes[i, j].matshow(\n", + " mask_windows.numpy()[i + j + (i * (axes.shape[1] - 1)), :, :, 0], \n", + " cmap=windows_cm, \n", + " extent=[0, res, res, 0], \n", + " vmin=0, \n", + " vmax=len(windows_cm.colors)\n", + " )\n", + "fig.colorbar(im, ax=axes.ravel().tolist())\n", + "\n", + "mask_windows = tf.reshape(mask_windows, [-1, window_size * window_size])\n", + "\n", + "fig, axes = plt.subplots(mask_windows.shape[0])\n", + "for i in range(axes.shape[0]):\n", + " im = axes[i].matshow(\n", + " mask_windows.numpy()[i].reshape((1, -1)), \n", + " cmap=windows_cm,\n", + " extent=[0, res, res, 0], \n", + " vmin=0, \n", + " vmax=len(windows_cm.colors),\n", + " )\n", + "fig.colorbar(im, ax=axes.ravel().tolist())\n", + "\n", + "attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)\n", + "\n", + "attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)\n", + "fig, axes = plt.subplots(windows_res, windows_res)\n", + "for i in range(windows_res):\n", + " for j in range(windows_res):\n", + " im = axes[i, j].matshow(\n", + " attn_mask.numpy()[i + j + (i * (axes.shape[1] - 1))],\n", + " extent=[0, res, res, 0],\n", + " cmap=\"Greens\",\n", + " vmin=-100,\n", + " vmax=0\n", + " )\n", + "fig.colorbar(im, ax=axes.ravel().tolist())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + }, + "orig_nbformat": 4 
+ }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 07de284..7fd26b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools", "wheel"] +requires = ["setuptools"] build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/src/swin/model.py b/src/swin/model.py index aaee273..443d638 100644 --- a/src/swin/model.py +++ b/src/swin/model.py @@ -1,21 +1,34 @@ -"""The Swin model definition module.""" +"""The Swin model definition module. + +Attributes: + DEFAULT_DROP_RATE: Default probability of dropping connections in + ``Dropout`` layers. + DEFAULT_DROP_PATH_RATE: Default maximum probability of entirely skipping a + (Shifted) Windows Multi-head Attention computation (Stochastic Depth + computation-skipping technique) during training. + This maximum value is used in the last stage of the network, while + previous stages use linearly spaced values in the + [0 ,``drop_path_rate``] interval. +""" + +import collections.abc import numpy as np import tensorflow as tf import tensorflow_addons as tfa -import collections.abc -from swin.modules import SwinPatchEmbeddings, SwinStage, SwinLinear +from swin.modules import SwinLinear, SwinPatchEmbeddings, SwinStage + +DEFAULT_DROP_RATE: float = 0.0 +DEFAULT_DROP_PATH_RATE: float = 0.1 class Swin(tf.keras.Model): """Swin Transformer Model. - Some assumptions have been made about this model: + To stay consistent with the architecture described in the paper, this class + assumes the following: - - ``inputs`` must always be a color image (3 channels). - - The size of windows in (Shifted) Windows Multi-head Attention is fixed - to 7. - The ratio of hidden to output neurons in ``SwinMlp`` layers is fixed to 4. - A learnable bias is added to ``queries``, ``keys`` and ``values`` @@ -23,57 +36,54 @@ class Swin(tf.keras.Model): - ``queries`` and ``keys`` are scaled by a factor of ``head_dimension**-0.5``. 
- No dropout is applied to attention heads. - - The probability of the Stochastic Depth computation-skipping technique - during training is fixed to 0.1. - No absolute position information is included in embeddings. - ``LayerNormalization`` is applied after building patch embeddings. Args: - inputs: The input to be expected by the model. It must describe a batch - of images in the ``channels_last`` format. Images must have height - equal to width (they must be square images). num_classes: The number of classes to predict. It determines the dimension of the output tensor. - patch_size: The size of each patch in which images will be divided into. - embed_dim: The lenght of embeddings built from patches. + patch_size: The size of patches in which images will be divided into. + Expressed in pixels. + window_size: The size of windows in (Shifted) Windows Multi-head + Attention layers expressed in patches per axis. + embed_dim: The length of embeddings built from patches. depths: The number of ``SwinTransformer`` layers in each stage of the network. num_heads: The number of (Shifted) Windows Multi-head Attention heads in each stage of the network. drop_rate: The probability of dropping connections in ``Dropout`` layers. + drop_path_rate: The maximum probability of entirely skipping a (Shifted) + Windows Multi-head Attention computation (Stochastic Depth + computation-skipping technique) during training. + This maximum value is used in the last stage of the network, while + previous stages use linearly spaced values in the + [0 ,``drop_path_rate``] interval. 
""" def __init__( self, - inputs: tf.Tensor, num_classes: int, patch_size: int = 4, + window_size: int = 7, embed_dim: int = 96, depths: collections.abc.Collection[int] = (2, 2, 6, 2), num_heads: collections.abc.Collection[int] = (3, 6, 12, 24), - drop_rate: float = 0.0, + drop_rate: float = DEFAULT_DROP_RATE, + drop_path_rate: float = DEFAULT_DROP_PATH_RATE, **kwargs, ) -> None: super().__init__(**kwargs) - assert inputs.dtype == tf.float32 - assert inputs.shape[1] == inputs.shape[2] and inputs.shape[3] == 3 - - self.input_shape_list = [ - inputs.shape[0], - inputs.shape[1], - inputs.shape[2], - inputs.shape[3], - ] # When returning this model's config, we only need axes' shapes, not the whole input tensor self.num_classes = num_classes self.patch_size = patch_size + self.window_size = window_size self.embed_dim = embed_dim self.depths = depths self.num_layers = len(self.depths) self.num_heads = num_heads self.drop_rate = drop_rate - self.drop_path_rate = 0.1 + self.drop_path_rate = drop_path_rate self.patch_embed = SwinPatchEmbeddings( self.embed_dim, @@ -81,7 +91,6 @@ def __init__( norm_layer=True, name="patches_linear_embedding", ) - self.patch_embed.compute_output_shape(inputs.shape) self.pos_drop = tf.keras.layers.Dropout(rate=self.drop_rate) @@ -90,7 +99,7 @@ def __init__( # These tensor would then get returned as layer parameters through # calls to their get_config() methods, causing problems in the JSON # serialization as the built-in Python library cannot handle this - # type of objects and thus preventing model saving. + # type of objects, thus preventing model saving. 
drop_depth_rate = [ x for x in np.linspace( @@ -103,10 +112,9 @@ def __init__( self.blocks = tf.keras.Sequential( [ SwinStage( - input_resolution=self.patch_embed.patches_resolution[0] // (2**i), depth=depths[i], num_heads=num_heads[i], - window_size=7, + window_size=self.window_size, mlp_ratio=4.0, drop_p=drop_rate, drop_path_p=drop_depth_rate[sum(depths[:i]) : sum(depths[: i + 1])], @@ -120,20 +128,24 @@ def __init__( self.norm = tf.keras.layers.LayerNormalization( epsilon=1e-5, name="layer_normalization" ) - self.avgpool = tfa.layers.AdaptiveAveragePooling1D( - 1, name="adaptive_average_pooling" + self.avgpool = tfa.layers.AdaptiveAveragePooling2D( + [1, 1], name="adaptive_average_pooling" ) self.flatten = tf.keras.layers.Flatten(name="flatten") self.head = SwinLinear(num_classes, name="classification_head") + def build(self, input_shape: tf.TensorShape) -> None: + assert input_shape.rank == 4 + assert input_shape[1] == input_shape[2] and input_shape[3] == 3 + def call(self, inputs, **kwargs): x = self.patch_embed(inputs, **kwargs) x = self.pos_drop(x, **kwargs) x = self.blocks(x, **kwargs) + x = self.norm(x, **kwargs) x = self.avgpool(x, **kwargs) x = self.flatten(x, **kwargs) - x = self.head(x, **kwargs) x = tf.nn.softmax(x) @@ -141,27 +153,20 @@ def call(self, inputs, **kwargs): def get_config(self) -> dict: config = { - "input_shape_list": self.input_shape_list, "num_classes": self.num_classes, "patch_size": self.patch_size, + "window_size": self.window_size, "embed_dim": self.embed_dim, "depths": self.depths, "num_heads": self.num_heads, "drop_rate": self.drop_rate, + "drop_path_rate": self.drop_path_rate, } return config - @classmethod - def from_config(cls, config: dict) -> "Swin": - # Since we only have the shape of the input, we build a new random tensor. - # Dtype is fixed to tf.float32. 
- inputs = tf.random.uniform(config.pop("input_shape_list"), dtype=tf.float32) - - return cls(inputs, **config) - def __repr__(self) -> str: - return f"{self.__class__.__name__}(patch_size={self.patch_size}, embed_dim={self.embed_dim}, depths={self.depths}, num_heads={self.num_heads}, drop_rate={self.drop_rate})" + return f"{self.__class__.__name__}(num_classes={self.num_classes}, patch_size={self.patch_size}, window_size={self.window_size}, embed_dim={self.embed_dim}, depths={self.depths}, num_heads={self.num_heads}, drop_rate={self.drop_rate}, drop_path_rate={self.drop_path_rate})" class SwinT(Swin): @@ -170,47 +175,53 @@ class SwinT(Swin): This version (tiny) uses the following options: - ``patch_size`` = 4 + - ``window_size`` = 7 - ``embed_dim`` = 96 - ``depths`` = (2, 2, 6, 2) - ``num_heads`` = (3, 6, 12, 24) - Some assumptions have been made about this model: + To stay consistent with the architecture described in the paper, this class + assumes the following: - - ``inputs`` must always be a coloured image (3 channels) - - The size of windows in (Shifted) Windows Multi-head Attention is fixed - to 7. - The ratio of hidden to output neurons in ``SwinMlp`` layers is fixed to 4. - A learnable bias is added to ``queries``, ``keys`` and ``values`` when computing (Shifted) Window Multi-head Attention. - ``queries`` and ``keys`` are scaled by a factor of ``head_dimension**-0.5``. - - No dropout is applied to Attention heads. - - The probability of the Stochastic Depth technique is fixed to 0.1. + - No dropout is applied to attention heads. - No absolute position information is included in embeddings. - ``LayerNormalization`` is applied after building patch embeddings. Args: - inputs: The input to be expected by the model. It must describe a batch - of images in the ``channels_last`` format. Images must have height - equal to width (they must be square images). num_classes: The number of classes to predict. It determines the dimension of the output tensor. 
drop_rate: The probability of dropping connections in ``Dropout`` layers. + drop_path_rate: The maximum probability of entirely skipping a (Shifted) + Windows Multi-head Attention computation (Stochastic Depth + computation-skipping technique) during training. + This maximum value is used in the last stage of the network, while + previous stages use linearly spaced values in the + [0 ,``drop_path_rate``] interval. """ def __init__( - self, inputs: tf.Tensor, num_classes: int, drop_rate: float = 0, **kwargs + self, + num_classes: int, + drop_rate: float = DEFAULT_DROP_RATE, + drop_path_rate: float = DEFAULT_DROP_PATH_RATE, + **kwargs, ) -> None: super().__init__( - inputs, num_classes, patch_size=4, + window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), drop_rate=drop_rate, + drop_path_rate=drop_path_rate, **kwargs, ) @@ -221,47 +232,53 @@ class SwinS(Swin): This version (small) uses the following options: - ``patch_size`` = 4 + - ``window_size`` = 7 - ``embed_dim`` = 96 - ``depths`` = (2, 2, 18, 2) - ``num_heads`` = (3, 6, 12, 24) - Some assumptions have been made about this model: + To stay consistent with the architecture described in the paper, this class + assumes the following: - - ``inputs`` must always be a coloured image (3 channels) - - The size of windows in (Shifted) Windows Multi-head Attention is fixed - to 7. - The ratio of hidden to output neurons in ``SwinMlp`` layers is fixed to 4. - A learnable bias is added to ``queries``, ``keys`` and ``values`` when computing (Shifted) Window Multi-head Attention. - ``queries`` and ``keys`` are scaled by a factor of ``head_dimension**-0.5``. - - No dropout is applied to Attention heads. - - The probability of the Stochastic Depth technique is fixed to 0.1. + - No dropout is applied to attention heads. - No absolute position information is included in embeddings. - ``LayerNormalization`` is applied after building patch embeddings. Args: - inputs: The input to be expected by the model. 
It must describe a batch - of images in the ``channels_last`` format. Images must have height - equal to width (they must be square images). num_classes: The number of classes to predict. It determines the dimension of the output tensor. drop_rate: The probability of dropping connections in ``Dropout`` layers. + drop_path_rate: The maximum probability of entirely skipping a (Shifted) + Windows Multi-head Attention computation (Stochastic Depth + computation-skipping technique) during training. + This maximum value is used in the last stage of the network, while + previous stages use linearly spaced values in the + [0 ,``drop_path_rate``] interval. """ def __init__( - self, inputs: tf.Tensor, num_classes: int, drop_rate: float = 0, **kwargs + self, + num_classes: int, + drop_rate: float = DEFAULT_DROP_RATE, + drop_path_rate: float = DEFAULT_DROP_PATH_RATE, + **kwargs, ) -> None: super().__init__( - inputs, num_classes, patch_size=4, + window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), drop_rate=drop_rate, + drop_path_rate=drop_path_rate, **kwargs, ) @@ -272,47 +289,53 @@ class SwinB(Swin): This version (base) uses the following options: - ``patch_size`` = 4 + - ``window_size`` = 7 - ``embed_dim`` = 128 - ``depths`` = (2, 2, 18, 2) - ``num_heads`` = (4, 8, 16, 32) - Some assumptions have been made about this model: + To stay consistent with the architecture described in the paper, this class + assumes the following: - - ``inputs`` must always be a coloured image (3 channels) - - The size of windows in (Shifted) Windows Multi-head Attention is fixed - to 7. - The ratio of hidden to output neurons in ``SwinMlp`` layers is fixed to 4. - A learnable bias is added to ``queries``, ``keys`` and ``values`` when computing (Shifted) Window Multi-head Attention. - ``queries`` and ``keys`` are scaled by a factor of ``head_dimension**-0.5``. - - No dropout is applied to Attention heads. 
- - The probability of the Stochastic Depth technique is fixed to 0.1. + - No dropout is applied to attention heads. - No absolute position information is included in embeddings. - ``LayerNormalization`` is applied after building patch embeddings. Args: - inputs: The input to be expected by the model. It must describe a batch - of images in the ``channels_last`` format. Images must have height - equal to width (they must be square images). num_classes: The number of classes to predict. It determines the dimension of the output tensor. drop_rate: The probability of dropping connections in ``Dropout`` layers. + drop_path_rate: The maximum probability of entirely skipping a (Shifted) + Windows Multi-head Attention computation (Stochastic Depth + computation-skipping technique) during training. + This maximum value is used in the last stage of the network, while + previous stages use linearly spaced values in the + [0 ,``drop_path_rate``] interval. """ def __init__( - self, inputs: tf.Tensor, num_classes: int, drop_rate: float = 0, **kwargs + self, + num_classes: int, + drop_rate: float = DEFAULT_DROP_RATE, + drop_path_rate: float = DEFAULT_DROP_PATH_RATE, + **kwargs, ) -> None: super().__init__( - inputs, num_classes, patch_size=4, + window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), drop_rate=drop_rate, + drop_path_rate=drop_path_rate, **kwargs, ) @@ -323,46 +346,52 @@ class SwinL(Swin): This version (large) uses the following options: - ``patch_size`` = 4 + - ``window_size`` = 7 - ``embed_dim`` = 192 - ``depths`` = (2, 2, 18, 2) - ``num_heads`` = (6, 12, 24, 48) - Some assumptions have been made about this model: + To stay consistent with the architecture described in the paper, this class + assumes the following: - - ``inputs`` must always be a coloured image (3 channels) - - The size of windows in (Shifted) Windows Multi-head Attention is fixed - to 7. - The ratio of hidden to output neurons in ``SwinMlp`` layers is fixed to 4. 
- A learnable bias is added to ``queries``, ``keys`` and ``values`` when computing (Shifted) Window Multi-head Attention. - ``queries`` and ``keys`` are scaled by a factor of ``head_dimension**-0.5``. - - No dropout is applied to Attention heads. - - The probability of the Stochastic Depth technique is fixed to 0.1. + - No dropout is applied to attention heads. - No absolute position information is included in embeddings. - ``LayerNormalization`` is applied after building patch embeddings. Args: - inputs: The input to be expected by the model. It must describe a batch - of images in the ``channels_last`` format. Images must have height - equal to width (they must be square images). num_classes: The number of classes to predict. It determines the dimension of the output tensor. drop_rate: The probability of dropping connections in ``Dropout`` layers. + drop_path_rate: The maximum probability of entirely skipping a (Shifted) + Windows Multi-head Attention computation (Stochastic Depth + computation-skipping technique) during training. + This maximum value is used in the last stage of the network, while + previous stages use linearly spaced values in the + [0 ,``drop_path_rate``] interval. 
""" def __init__( - self, inputs: tf.Tensor, num_classes: int, drop_rate: float = 0, **kwargs + self, + num_classes: int, + drop_rate: float = DEFAULT_DROP_RATE, + drop_path_rate: float = DEFAULT_DROP_PATH_RATE, + **kwargs, ) -> None: super().__init__( - inputs, num_classes, patch_size=4, + window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), drop_rate=drop_rate, + drop_path_rate=drop_path_rate, **kwargs, ) diff --git a/src/swin/modules.py b/src/swin/modules.py index 0dd2d04..f510c34 100644 --- a/src/swin/modules.py +++ b/src/swin/modules.py @@ -1,6 +1,7 @@ """Modules used by the Swin Transformer.""" import collections.abc + import numpy as np import tensorflow as tf @@ -18,7 +19,7 @@ class SwinLinear(tf.keras.layers.Dense): use_bias: Whether the layer uses a bias vector. """ - def __init__(self, units: int, use_bias=True, **kwargs) -> None: + def __init__(self, units: int, use_bias: bool = True, **kwargs) -> None: super().__init__( units, activation=tf.keras.activations.linear, @@ -38,7 +39,7 @@ class SwinPatchEmbeddings(tf.keras.layers.Layer): Args: embed_dim: Dimension of output embeddings. - patch_size: Size of axes of image patches, expressed in pixels. + patch_size: Height/width of patches, expressed in pixels. norm_layer: Whether to apply layer normalization or not. """ @@ -79,25 +80,21 @@ def build(self, input_shape: tf.TensorShape) -> None: ) self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] - self.flatten = tf.keras.layers.Reshape((-1, self.embed_dim)) - def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: """Build embeddings for every patch of the image. Args: - inputs: A batch of images with shape (batch_size, height, width, - channels). + inputs: A batch of images with shape ``(batch_size, height, width, channels)``. + ``height`` and ``width`` must be identical. Returns: - Embeddings, having shape ``(batch_size, num_patches, embed_dim)``. 
+ Embeddings, having shape ``(batch_size, height / patch_size, + width / patch_size, embed_dim)``. """ - x = tf.ensure_shape(inputs, [None, None, None, 3]) - - x = self.proj(x, **kwargs) - x = self.flatten(x, **kwargs) + x = self.proj(inputs, **kwargs) - if self.norm: + if self.norm is not None: x = self.norm(x, **kwargs) return x @@ -125,90 +122,76 @@ class SwinPatchMerging(tf.keras.layers.Layer): patches. """ - def __init__(self, input_resolution: int, **kwargs) -> None: - # NOTE: Changed input_resolution from tuple to int - + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - assert input_resolution % 2 == 0 - self.input_resolution = input_resolution - self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5) def build(self, input_shape: tf.TensorShape): - self.reduction = SwinLinear(input_shape[-1] * 2, use_bias=False) + assert input_shape.rank == 4 + assert input_shape[1] == input_shape[2] + assert input_shape[1] % 2 == 0 + + self.reduction = SwinLinear(input_shape[3] * 2, use_bias=False) def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: - """Perform the merging of patches. + """Merge groups of 4 neighbouring patches. - The merge is performed on groups of 4 neighbouring patches. + This layer concatenates the features of groups of 4 neighbouring patches + and projects the concatenation into a space twice the length of the + original feature space. Args: - inputs: Tensor of patches, with shape ``(batch_size, - num_patches, embed_dim)`` with - ``num_patches = input_resolution * input_resolution``. + inputs: Tensor of patches, with shape + ``(batch_size, height_patches, width_patches, embed_dim)``. + ``height_patches`` must be equal to ``width_patches``. Returns: - Embeddings of merged patches, with shape ``(batch_size, num_patches / 4, 2 * embed_dim)``. + Embeddings of merged patches, with shape ``(batch_size, + height_patches / 2, width_patches / 2, 2 * embed_dim)``.
""" - tf.assert_equal(inputs.dtype, tf.float32, "Inputs must be a tf.float32 tensor.") - x = tf.ensure_shape(inputs, [None, self.input_resolution**2, None]) - - shape = tf.shape(inputs) - batch = shape[0] - channels = shape[2] - - x = tf.reshape( - x, [batch, self.input_resolution, self.input_resolution, channels] - ) - - x0 = x[:, 0::2, 0::2, :] - x1 = x[:, 1::2, 0::2, :] - x2 = x[:, 0::2, 1::2, :] - x3 = x[:, 1::2, 1::2, :] - - x = tf.concat([x0, x1, x2, x3], axis=-1) - x = tf.reshape(x, [batch, -1, 4 * channels]) + x = tf.concat( + [ + inputs[:, 0::2, 0::2, :], + inputs[:, 1::2, 0::2, :], + inputs[:, 0::2, 1::2, :], + inputs[:, 1::2, 1::2, :], + ], + axis=-1, + ) # [batch_size, height_patches / 2, width_patches / 2, 4 * embed_dim] x = self.norm(x, **kwargs) - x = self.reduction(x, **kwargs) + x = self.reduction( + x, **kwargs + ) # [batch_size, height_patches / 2, width_patches / 2, 2 * embed_dim] return x - def get_config(self) -> dict: - config = super().get_config() - config.update({"input_resolution": self.input_resolution}) - return config - def __repr__(self) -> str: - return f"{self.__class__.__name__}(input_resolution={self.input_resolution})" + return f"{self.__class__.__name__}()" class SwinStage(tf.keras.layers.Layer): """Stage of the Swin Network. Args: - input_resolution: The resolution of axes of the input, expressed in - number of patches. - depth: Number of SwinTransformer layers in the stage. - num_heads: Number of attention heads in each SwinTransformer layer. - window_size: The size of windows in which embeddings gets split into, - expressed in numer of patches. + depth: Number of ``SwinTransformer`` layers in the stage. + num_heads: Number of attention heads in each ``SwinTransformer`` layer. + window_size: The size of window axes expressed in patches. mlp_ratio: The ratio between the size of the hidden layer and the size - of the output layer in SwinMlp layers. 
- drop_p: The probability of dropping connections in a SwinTransformer + of the output layer in ``SwinMlp`` layers. + drop_p: The probability of dropping connections in a ``SwinTransformer`` layer during training. drop_path_p: The proabability of entirely skipping the computation of (Shifted) Windows Multi-head Self Attention during training (Stochastic Depth technique). - downsample: Whether or not to apply downsampling at the end of the - layer. + downsample: Whether or not to apply downsampling through a + ``SwinPatchMerging`` layer at the end of the stage. """ def __init__( self, - input_resolution: int, depth: int, num_heads: int, window_size: int, @@ -220,19 +203,17 @@ def __init__( ) -> None: super().__init__(**kwargs) - self.input_resolution = input_resolution self.depth = depth self.num_heads = num_heads self.window_size = window_size self.mlp_ratio = mlp_ratio self.drop_p = drop_p self.drop_path_p = drop_path_p - self.donwsample = downsample + self.downsample = downsample self.core = tf.keras.Sequential( [ SwinTransformer( - resolution=self.input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, @@ -246,29 +227,32 @@ def __init__( ] ) - if downsample: - self.downsample_layer = SwinPatchMerging(self.input_resolution) - else: - self.downsample_layer = None + self.downsample_layer = SwinPatchMerging() if downsample else None + + def build(self, input_shape: tf.TensorShape): + assert ( + input_shape.rank == 4 + ) # Must be batch_size, height_patches, width_patches, embed_dim + assert input_shape[1] == input_shape[2] def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: """Apply transformations of the Swin stage to patches. Args: - inputs: The input patches to the Swin stage, having shape `` - (batch_size, num_patches, embed_dim)``. + inputs: The input patches to the Swin stage, having shape + ``(batch_size, height_patches, width_patches, embed_dim)``. 
+ ``height_patches`` must be equal to ``width_patches``. Returns: - Transformed patches with shape ``(batch_size, num_patches / 4, - embed_dim * 2)`` if ``downsample == True`` or ``(batch_size, - num_patches, embed_dim)`` if ``downsample == False``. + Transformed patches with shape ``(batch_size, height_patches / 2, + width_patches / 2, embed_dim * 2)`` if ``downsample == True`` + or ``(batch_size, height_patches, width_patches, embed_dim)`` + if ``downsample == False``. """ - x = tf.ensure_shape(inputs, [None, None, None]) - - x = self.core(x, **kwargs) + x = self.core(inputs, **kwargs) - if self.donwsample: + if self.downsample: x = self.downsample_layer(x, **kwargs) return x @@ -277,28 +261,25 @@ def get_config(self) -> dict: config = super().get_config() config.update( { - "input_resolution": self.input_resolution, "depth": self.depth, "num_heads": self.num_heads, "window_size": self.window_size, "mlp_ratio": self.mlp_ratio, "drop_p": self.drop_p, "drop_path_p": self.drop_path_p, - "downsample": self.donwsample, + "downsample": self.downsample, } ) return config def __repr__(self) -> str: - return f"{self.__class__.__name__}(input_resolution={self.input_resolution}, depth={self.depth}, num_heads={self.num_heads}, window_size={self.window_size}, mlp_ratio={self.mlp_ratio}, drop_p={self.drop_p}, drop_path_p={self.drop_path_p}, downsample={self.donwsample})" + return f"{self.__class__.__name__}(depth={self.depth}, num_heads={self.num_heads}, window_size={self.window_size}, mlp_ratio={self.mlp_ratio}, drop_p={self.drop_p}, drop_path_p={self.drop_path_p}, downsample={self.downsample})" class SwinWindowAttention(tf.keras.layers.Layer): """Swin (Shifted) Window Multi-head Self Attention Layer. Args: - window_size: The size of windows in which embeddings gets divided into, - expressed in patches. num_heads: The number of attention heads. proj_drop_r: The ratio of output weights that randomly get dropped during training. 
@@ -306,42 +287,71 @@ class SwinWindowAttention(tf.keras.layers.Layer): def __init__( self, - window_size: int, num_heads: int, proj_drop_r: float = 0.0, **kwargs, ) -> None: super().__init__(**kwargs) - self.window_size = window_size self.num_heads = num_heads self.proj_drop_r = proj_drop_r - # TODO: Change into TF calls to get rid of numpy - coords_h = range(self.window_size) - coords_w = range(self.window_size) - coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij")) - coords_flat = np.reshape(coords, [coords.shape[0], -1]) - relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :] - relative_coords = np.transpose(relative_coords, [1, 2, 0]) - relative_coords[:, :, 0] += self.window_size - 1 - relative_coords[:, :, 1] += self.window_size - 1 - relative_coords[:, :, 0] *= 2 * self.window_size - 1 - relative_position_index = relative_coords.sum(-1) - - self.relative_position_index = tf.Variable( - initial_value=tf.convert_to_tensor(relative_position_index), - trainable=False, - name="relative_position_index", - ) - self.proj_drop = tf.keras.layers.Dropout(self.proj_drop_r) self.softmax = tf.keras.layers.Softmax(-1) + @classmethod + def build_relative_position_index(cls, window_size: int) -> tf.Tensor: + """Build the table of relative position indices. + + This table is used as an index to the relative position table. For each + pair of tokens in a window, this table allows to get the index in the + relative position table. + + Args: + window_size: The size of windows (expressed in patches) used during + the (S)W-MSA. + + Returns: + A ``Tensor`` with shape ``(window_size**2, window_size**2)`` + representing indices in the relative position table for each pair of + patches in the window. 
+ """ + + coords = tf.range(0, window_size) + coords = tf.stack(tf.meshgrid(coords, coords, indexing="ij")) + coords = tf.reshape(coords, [tf.shape(coords)[0], -1]) + + rel_coords = tf.expand_dims(coords, 2) - tf.expand_dims( + coords, 1 + ) # Make values relative + rel_coords = tf.transpose(rel_coords, [1, 2, 0]) + + rel_coords = tf.Variable(rel_coords) + + rel_coords[:, :, 0].assign( + rel_coords[:, :, 0] + window_size - 1 + ) # Add offset to values + rel_coords[:, :, 1].assign(rel_coords[:, :, 1] + window_size - 1) + + rel_coords[:, :, 0].assign( + rel_coords[:, :, 0] * (2 * window_size - 1) + ) # Shift values so indices for different patches do not share the same value + + rel_pos_index = tf.reduce_sum(rel_coords, -1) + + return rel_pos_index + def build(self, input_shape: tf.TensorShape) -> None: - channels = input_shape[-1] + assert input_shape.rank == 5 + assert input_shape[2] == input_shape[3] + assert ( + input_shape[4] % self.num_heads == 0 + ) # embeddings dimension must be evenly divisible by the number of attention heads - self.head_dim = channels // self.num_heads + self.window_size = input_shape[2] + embed_dim = input_shape[4] + + self.head_dim = embed_dim // self.num_heads self.scale = self.head_dim**-0.5 # In the paper, sqrt(d) # The official implementation uses a custom function which defaults @@ -359,8 +369,17 @@ def build(self, input_shape: tf.TensorShape) -> None: trainable=True, ) - self.qkv = SwinLinear(channels * 3) - self.proj = SwinLinear(channels) + self.relative_position_index = tf.Variable( + initial_value=tf.reshape( + SwinWindowAttention.build_relative_position_index(self.window_size), + [-1], + ), # Flatten the matrix so it can be used to index the relative_position_bias_table in the forward pass + trainable=False, + name="relative_position_index", + ) + + self.qkv = SwinLinear(embed_dim * 3) + self.proj = SwinLinear(embed_dim) def call( self, inputs: tf.Tensor, mask: tf.Tensor | None = None, **kwargs @@ -368,8 +387,8 @@ def call( 
"""Perform (Shifted) Window MSA. Args: - inputs: Embeddings with shape ``(num_windows * batch_size, - window_size * window_size, embed_dim)``. ``embed_dim`` must be + inputs: Embeddings with shape ``(batch_size, num_windows, + window_size, window_size, embed_dim)``. ``embed_dim`` must be exactly divisible by ``num_heads``. mask: Attention mask used used to perform Shifted Window MSA, having shape ``(num_windows, window_size * window_size, window_size * window_size)`` and values {0, -inf}. @@ -379,25 +398,21 @@ def call( input. """ - x = tf.ensure_shape(inputs, [None, self.window_size**2, None]) - shape = tf.shape(inputs) - batch_windows = shape[0] - window_dim = shape[1] - embed_dim = shape[2] - - tf.assert_equal( - embed_dim % self.num_heads, - 0, - "Provided input dimension 3 (embed_dim) is not evenly divisible by the number of attention heads.", - ) + batch_windows = shape[0] * shape[1] + window_dim = shape[2] * shape[3] + embed_dim = shape[4] - qkv = self.qkv(x, **kwargs) + x = tf.reshape(inputs, [batch_windows, window_dim, embed_dim]) + + qkv = self.qkv(x, **kwargs) # [batch_windows, window_dim, 3 * embed_dim] qkv = tf.reshape( qkv, - [batch_windows, window_dim, 3, self.num_heads, embed_dim // self.num_heads], + [batch_windows, window_dim, 3, self.num_heads, self.head_dim], ) - qkv = tf.transpose(qkv, [2, 0, 3, 1, 4]) + qkv = tf.transpose( + qkv, [2, 0, 3, 1, 4] + ) # [3, batch_windows, num_heads, window_dim, head_dim] q = qkv[0] k = qkv[1] @@ -405,40 +420,58 @@ def call( q = q * self.scale - attn = tf.matmul(q, tf.transpose(k, [0, 1, 3, 2])) + attn = tf.matmul( + q, k, transpose_b=True + ) # [batch_windows, num_heads, window_dim, window_dim] - indices = tf.reshape(self.relative_position_index, [-1]) - relative_position_bias = tf.gather(self.relative_position_bias_table, indices) + relative_position_bias = tf.gather( + self.relative_position_bias_table, self.relative_position_index + ) # [window_dim**2, num_heads] relative_position_bias = tf.reshape( 
relative_position_bias, [window_dim, window_dim, -1] ) - relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1]) + relative_position_bias = tf.transpose( + relative_position_bias, [2, 0, 1] + ) # [num_heads, window_dim, window_dim] - attn = attn + tf.expand_dims(relative_position_bias, axis=0) + attn = attn + tf.expand_dims( + relative_position_bias, axis=0 + ) # [batch_windows, num_heads, window_dim, window_dim] if mask is not None: - nW = tf.shape(mask)[0] + num_windows = tf.shape(mask)[0] attn = tf.reshape( - attn, [batch_windows // nW, nW, self.num_heads, window_dim, window_dim] - ) + attn, + [ + batch_windows // num_windows, + num_windows, + self.num_heads, + window_dim, + window_dim, + ], + ) # Expand to [batch_size, num_windows, num_heads, window_dim, window_dim] in order to sum the attention mask attn = attn + tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0) - attn = tf.reshape(attn, [-1, self.num_heads, window_dim, window_dim]) + attn = tf.reshape( + attn, [-1, self.num_heads, window_dim, window_dim] + ) # Back to [batch_windows, num_heads, window_dim, window_dim] attn = self.softmax(attn, **kwargs) - x = tf.matmul(attn, v) - x = tf.transpose(x, [0, 2, 1, 3]) - x = tf.reshape(x, [batch_windows, window_dim, embed_dim]) - x = self.proj(x, **kwargs) - x = self.proj_drop(x, **kwargs) + attn = tf.matmul(attn, v) + attn = tf.transpose(attn, [0, 2, 1, 3]) + attn = tf.reshape(attn, [batch_windows, window_dim, embed_dim]) - return x + attn = self.proj(attn, **kwargs) + attn = self.proj_drop(attn, **kwargs) + + attn = tf.reshape(attn, tf.shape(inputs)) + + return attn def get_config(self) -> dict: config = super().get_config() config.update( { - "window_size": self.window_size, "num_heads": self.num_heads, "proj_drop_r": self.proj_drop_r, } @@ -446,23 +479,40 @@ def get_config(self) -> dict: return config def __repr__(self) -> str: - return f"{self.__class__.__name__}(window_size={self.window_size}, num_heads={self.num_heads}, 
proj_drop_r={self.proj_drop_r})" + return f"{self.__class__.__name__}(num_heads={self.num_heads}, proj_drop_r={self.proj_drop_r})" class SwinDropPath(tf.keras.layers.Layer): - """Stochastic Depth Layer. + """Stochastic per-sample layer drop. + + This is an implementation of the stochastic depth technique described in the + "Deep Networks with Stochastic Depth" paper by Huang et al. + (https://arxiv.org/pdf/1603.09382.pdf). + + Examples in a batch have a probability to have their values set to 0. + This is useful in conjunction with residual paths, as adding the residual + connection with 0 yields the original example, as if other computations + never took place in the main path. Args: - drop_prob: The probability of entirely skipping the output of the - computation. + drop_prob: The probability of entirely skipping the layer. """ def __init__(self, drop_prob: float = 0.0, **kwargs) -> None: super().__init__(**kwargs) + assert drop_prob >= 0 and drop_prob <= 1 + self.drop_prob = drop_prob self.keep_prob = 1 - self.drop_prob + def build(self, input_shape: tf.TensorShape) -> None: + # We want to get a rank-1 tensor, with tf.rank(inputs) values all set to + # 1 except for the first one, identical to the batch size. + # e.g. [4, 1, 1, 1]. + self.shape = tf.ones([input_shape.rank], dtype=tf.int32) + self.shape = tf.tensor_scatter_nd_update(self.shape, [[0]], [input_shape[0]]) + def call( self, inputs: tf.Tensor, training: tf.Tensor = None, **kwargs ) -> tf.Tensor: @@ -472,8 +522,8 @@ def call( inputs: The input data. The first dimension is assumed to be the ``batch_size``. training: Whether the forward pass is happening at training time - or not. During inference (``training`` = False) ``inputs`` is - returned as-is. + or not. During inference (``training = False``) ``inputs`` is + returned as-is (i.e. no drops). Returns: The input tensor with some values randomly set to 0. 
@@ -482,21 +532,9 @@ def call( if self.drop_prob == 0 or not training: return inputs - first_axis = tf.expand_dims(tf.shape(inputs)[0], axis=0) - other_axis = tf.repeat( - 1, tf.rank(inputs) - 1 - ) # Rank-1 tensor with (rank(inputs) - 1) axes, all having value 1 - - # We want to get a rank-1 tensor with 1 as the value of all axes except - # for the first one, identical to the batch size - shape = tf.concat( - [first_axis, other_axis], - axis=0, - ) - rand_tensor = tf.constant(self.keep_prob, dtype=inputs.dtype) rand_tensor = rand_tensor + tf.random.uniform( - shape, maxval=1.0, dtype=inputs.dtype + self.shape, maxval=1.0, dtype=inputs.dtype ) rand_tensor = tf.floor(rand_tensor) @@ -535,20 +573,22 @@ def __init__( self.fc2 = SwinLinear(self.out_features) self.drop = tf.keras.layers.Dropout(self.drop_p) + def build(self, input_shape: tf.TensorShape) -> None: + assert input_shape.rank == 4 + def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: """Apply the transformations of the MLP. Args: - inputs: The input data, having shape ``(batch_size, num_patches, - embed_size)``. + inputs: The input data, having shape + ``(batch_size, height_patches, width_patches, embed_size)``. Returns: - The transformed inputs, with shape ``(batch_size, num_patches, - out_features)``. + The transformed inputs, with shape + ``(batch_size, num_patches, out_features)``. """ - x = tf.ensure_shape(inputs, [None, None, None]) - x = self.fc1(x, **kwargs) + x = self.fc1(inputs, **kwargs) x = tf.nn.gelu(x) x = self.drop(x, **kwargs) x = self.fc2(x, **kwargs) @@ -562,6 +602,7 @@ def get_config(self) -> dict: { "hidden_features": self.hidden_features, "out_features": self.out_features, + "drop_p": self.drop_p, } ) return config @@ -574,25 +615,20 @@ class SwinTransformer(tf.keras.layers.Layer): """Swin Transformer Layer. Args: - resolution: The input resolution expressed in number of patches per - axis. Both axis share the same resolution as the orginal image - must be a square. 
- num_heads: The number of Shifted Window Attention heads. - window_size: The size of windows in which the image gets partitioned - into, expressed in patches. + num_heads: The number of (Shifted) Window Attention heads. + window_size: The size of window axes, expressed in patches. shift_size: The value of shifting applied to windows, expressed in patches. mlp_ratio: The ratio between the size of the hidden layer and the - size of the output layer in SwinMlp. - drop_p: The probability of dropping connections in Dropout layers during - training. + size of the output layer in ``SwinMlp``. + drop_p: The probability of dropping connections in ``Dropout`` layers + during training. drop_path_p: The probability of entirely skipping a (Shifted) Windows Multi-head Self Attention computation during training. """ def __init__( self, - resolution: int, num_heads: int, window_size: int, shift_size: int, @@ -603,7 +639,6 @@ def __init__( ) -> None: super().__init__(**kwargs) - self.resolution = resolution self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size @@ -611,82 +646,34 @@ def __init__( self.drop_p = drop_p self.drop_path_p = drop_path_p - if self.resolution <= self.window_size: - self.shift_size = 0 - self.window_size = self.resolution - - # Resolution must be evenly divisible by the window size or reshape - # operations will not work - assert self.resolution % self.window_size == 0 - - assert 0 <= self.shift_size < self.window_size - self.norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-5) - self.attention = SwinWindowAttention( - self.window_size, self.num_heads, proj_drop_r=drop_p - ) + self.attention = SwinWindowAttention(self.num_heads, proj_drop_r=drop_p) # When drop_path_p == 0 SwinDropPath simply returns the same value self.drop_path = SwinDropPath(self.drop_path_p) self.norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-5) - if self.shift_size > 0: - attn_mask = self.build_attn_mask( - self.resolution, - 
self.window_size, - self.shift_size, - ) - - self.attn_mask = tf.Variable( - initial_value=attn_mask, - trainable=False, - name="attention_mask", - ) - else: - self.attn_mask = None - @classmethod - @tf.function( - input_signature=[ - tf.TensorSpec(shape=[None, None, None, None], dtype=tf.float32), - tf.TensorSpec(shape=[], dtype=tf.int32), - ], - ) def window_partition(cls, patches: tf.Tensor, window_size: tf.Tensor) -> tf.Tensor: """Partition a batch of images into windows. - .. Note:: - - This method may throw warnings due to an excessive number of - retracing operations. - However, due to it being used in the forward pass of the full - model, keeping it decorated as a ``tf.function`` should still prove - to be beneficial. - Args: - patches: Patch embeddings for a batch of images to partition - into windows, having shape ``(batch_size, num_patches_h, - num_patches_w, embed_dim)``. ``num_patches_h == num_patches_w``. - window_size: The size of each window, expressed in patches. + patches: A batch of patch embeddings to partition into windows, + having shape ``(batch_size, num_patches_h, num_patches_w, + embed_dim)``. + window_size: The size of each window, expressed in patches along + each axis. Returns: - A tensor of windows having shape ``(n * batch_size, window_size, + A tensor of windows having shape ``(batch_size, n, window_size, window_size, embed_dim)``, where ``n`` is the number of resulting windows. 
""" - x = tf.ensure_shape(patches, [None, None, None, None]) - window_size = tf.ensure_shape(window_size, []) - - shape = tf.shape(x) - tf.assert_equal( - shape[1], - shape[2], - "The number of patches in the height dimension must be equal to the number of patches in the width dimension (patches must be squared).", - ) + shape = tf.shape(patches) windows = tf.reshape( - x, + patches, [ shape[0], shape[1] // window_size, @@ -697,104 +684,52 @@ def window_partition(cls, patches: tf.Tensor, window_size: tf.Tensor) -> tf.Tens ], ) windows = tf.transpose(windows, [0, 1, 3, 2, 4, 5]) - windows = tf.reshape(windows, [-1, window_size, window_size, shape[3]]) + windows = tf.reshape( + windows, [shape[0], -1, window_size, window_size, shape[3]] + ) return windows @classmethod - @tf.function - def window_reverse(cls, windows: tf.Tensor, patch_size: tf.Tensor) -> tf.Tensor: + def window_reverse(cls, windows: tf.Tensor, resolution: tf.Tensor) -> tf.Tensor: """Reverse the partitioning of a batch of patches into windows. + .. Note: + ``resolution`` is expected to be a multiple of the size of windows. + No checks are performed to ensure this holds. + Args: - windows: Partitioned windows to reverse, with shape ``(batch_size * - num_windows, window_size, window_size, embed_dim)``. - patch_size: Number of patches per axis in the original image. + windows: Partitioned windows to reverse, with shape + ``(batch_size, num_windows, window_size, window_size, embed_dim)``. + resolution: Number of patches per axis in the original feature map. Returns: A tensor of patches of the batch recreated from ``windows``, with - shape ``(batch_size, patch_size, patch_size, embed_dim)``. + shape ``(batch_size, resolution, resolution, embed_dim)``. 
""" - x = tf.ensure_shape(windows, [None, None, None, None]) - - tf.assert_equal( - tf.shape(x)[1], - tf.shape(x)[2], - "Dimension 1 and dimension 2 of 'windows' must be identical.", - ) - window_size = tf.shape(x)[1] + shape = tf.shape(windows) - # TODO: simplify - b = tf.cast(tf.shape(x)[0], tf.float64) # Casting to prevent type mismatch - d = patch_size**2 / window_size / tf.cast(window_size, tf.float64) - batch_size = tf.cast(b / d, tf.int32) + batch_size = shape[0] + window_size = shape[2] + embed_dim = shape[4] x = tf.reshape( - x, + windows, [ batch_size, - patch_size // window_size, - patch_size // window_size, + resolution // window_size, + resolution // window_size, window_size, window_size, - -1, + embed_dim, ], ) x = tf.transpose(x, [0, 1, 3, 2, 4, 5]) - x = tf.reshape(x, [batch_size, patch_size, patch_size, -1]) + x = tf.reshape(x, [batch_size, resolution, resolution, embed_dim]) return x - @classmethod - def masked_fill( - cls, tensor: tf.Tensor, mask: tf.Tensor, value: tf.Tensor - ) -> tf.Tensor: - """Fill elements of ``tensor`` with ``value`` where ``mask`` is True. - - This function returns a new tensor having the same values as ``tensor`` - except for those where ``mask`` contained the value True; these values are - replaced with ``value``. - - It mimics ``torch.tensor.masked_fill()``. - - ``mask`` must have identical shape to ``tensor`` and ``value`` must be a - scalar tensor. - ``value`` is cast to the type of ``tensor`` if their types don't match. - - Args: - tensor: The tensor to fill with ``value`` where ``mask`` is True. - mask: The mask to apply to ``tensor``. - value: The value to fill ``tensor`` with. - - Returns: - A copy of ``tensor`` with elements changed to ``value`` where - ``mask`` was ``True``. 
- """ - - tf.assert_equal( - tf.shape(tensor), - tf.shape(mask), - "The shape of tensor must match the shape of mask.", - ) - tf.assert_equal(tf.rank(value), 0, "'value' must be a scalar tensor.") - - if value.dtype != tensor.dtype: - value = tf.cast(value, tensor.dtype) - - indices = tf.where(mask) - - filled_tensor = tf.tensor_scatter_nd_update( - tensor, - indices, - tf.broadcast_to( - value, - [tf.shape(indices)[0]], - ), - ) - - return filled_tensor - @classmethod def build_attn_mask(cls, size: int, window_size: int, shift_size: int): """Build an attention mask for the Shifted Window MSA. @@ -808,7 +743,10 @@ def build_attn_mask(cls, size: int, window_size: int, shift_size: int): The computed attention mask, with shape ``(num_windows, window_size * window_size, window_size * window_size)``. """ - # TODO: Change mask creation to ditch numpy + # While possible to build the mask only through TensorFlow operations, + # it would result in a much less readable method.Since Numpy is already + # a TensorFlow dependency and this method is only called during this + # layer's initialization, using it to build the mask is fine. mask = np.zeros( [1, size, size, 1], dtype=np.float32 ) # Force type so we get a tf.float32 tensor as the output of this method. @@ -831,19 +769,62 @@ def build_attn_mask(cls, size: int, window_size: int, shift_size: int): mask_windows = SwinTransformer.window_partition( tf.convert_to_tensor(mask), tf.constant(window_size) - ) - mask_windows = tf.reshape(mask_windows, [-1, window_size * window_size]) + ) # mask_windows.shape = [n, window_size, window_size, 1]. + mask_windows = tf.reshape( + mask_windows, [-1, window_size * window_size] + ) # mask_windows.shape = [n, window_size**2], we flatten windows. + + # We need to create a mask which, for each patch in each window, tells + # us if the attention mechanism should be calculated for every other + # patch in the same window. + # This means a mask with shape [n, window_size**2, window_size**2]. 
+ # Subtracting the two expanded mask_windows gives us a tensor with the + # right shape and values equal to zero where two patches are adjacent in + # the original feature map (meaning attention should be calculated). attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2) - attn_mask = SwinTransformer.masked_fill( - attn_mask, attn_mask != 0, tf.constant(-100.0) - ) # TODO: check if -100 can be changed to -math.inf - attn_mask = SwinTransformer.masked_fill( - attn_mask, attn_mask == 0, tf.constant(0.0) - ) + + # We now need to change values != 0 to something negative. When put + # through the SoftMax operation performed during the SW-MSA, it results + # in a value close to 0 for those patches that were not adjacent in the + # original feature map. + # Technically, the bigger the negative number the better + # (i.e. -math.inf), but it could lead to float values shenanigans so we + # choose -100 to stay consistent with the original implementation. + attn_mask = tf.where(attn_mask != 0, tf.constant(-100.0), attn_mask) return attn_mask def build(self, input_shape: tf.TensorShape) -> None: + assert input_shape.rank == 4 + assert input_shape[1] == input_shape[2] + + self.resolution = input_shape[1] + + # Resolution must be evenly divisible by the window size or reshape + # operations will not work + assert self.resolution % self.window_size == 0 + + if self.resolution <= self.window_size: + self.shift_size = 0 + self.window_size = self.resolution + + assert 0 <= self.shift_size < self.window_size + + if self.shift_size > 0: + attn_mask = self.build_attn_mask( + self.resolution, + self.window_size, + self.shift_size, + ) + + self.attn_mask = tf.Variable( + initial_value=attn_mask, + trainable=False, + name="attention_mask", + ) + else: + self.attn_mask = None + dim = input_shape[-1] mlp_hidden_dim = int(dim * self.mlp_ratio) self.mlp = SwinMlp(mlp_hidden_dim, out_features=dim, drop_p=self.drop_p) @@ -852,57 +833,35 @@ def call(self, inputs: 
tf.Tensor, **kwargs) -> tf.Tensor: """Apply the transformations of the transformer layer. Args: - inputs: Input embeddings with shape ``(batch_size, num_patches, embed_dim)``. + inputs: Input embeddings with shape + ``(batch_size, height_patches, width_patches, embed_dim)``. + ``height_patches`` must be equal to ``width_patches``. Returns: Transformed embeddings with same shape as ``inputs``. """ - x = tf.ensure_shape(inputs, [None, self.resolution * self.resolution, None]) + shortcut_1 = inputs - shape = tf.shape(inputs) - - batch = shape[0] - channels = shape[2] - - shortcut_1 = x - - x = self.norm_1(x, **kwargs) - x = tf.reshape(x, [batch, self.resolution, self.resolution, channels]) - shifted_x = x + # Layer normalization + x = self.norm_1(inputs, **kwargs) + # Cyclic shift if self.shift_size > 0: - shifted_x = tf.roll( - x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2] - ) + x = tf.roll(x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]) # Window partitioning - x_windows = self.window_partition(shifted_x, self.window_size) - x_windows = tf.reshape( - x_windows, [-1, self.window_size * self.window_size, channels] - ) + x = self.window_partition(x, self.window_size) # (Shifted) Window Multi-head Self Attention - attn_windows = self.attention(x_windows, mask=self.attn_mask, **kwargs) + x = self.attention(x, mask=self.attn_mask, **kwargs) - # Window merging - attn_windows = tf.reshape( - attn_windows, [-1, self.window_size, self.window_size, channels] - ) - shifted_x = self.window_reverse( - attn_windows, - tf.constant(self.resolution), - ) + # Undo window partitioning (window merging) + x = self.window_reverse(x, tf.constant(self.resolution)) - # Reverse cyclic shift + # Undo cyclic shift (reverse cyclic shift) if self.shift_size > 0: - x = tf.roll( - shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2] - ) - else: - x = shifted_x - - x = tf.reshape(x, [batch, self.resolution * self.resolution, channels]) + x = tf.roll(x, 
shift=[self.shift_size, self.shift_size], axis=[1, 2]) # Sum the skip connection and the output of (S)W-MSA x = shortcut_1 + self.drop_path(x, **kwargs) @@ -921,7 +880,6 @@ def get_config(self) -> dict: config = super().get_config() config.update( { - "resolution": self.resolution, "num_heads": self.num_heads, "window_size": self.window_size, "shift_size": self.shift_size, @@ -933,4 +891,4 @@ def get_config(self) -> dict: return config def __repr__(self) -> str: - return f"{self.__class__.__name__}(resolution={self.resolution}, window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, drop_p={self.drop_p}, drop_path_p={self.drop_path_p})" + return f"{self.__class__.__name__}(window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, drop_p={self.drop_p}, drop_path_p={self.drop_path_p})" diff --git a/tests/test_model.py b/tests/test_model.py index f3fa4bf..13c5a21 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,9 +1,11 @@ import pathlib import random import tempfile +import unittest + import tensorflow as tf + import swin.model as sm -import unittest class TestSwin(unittest.TestCase): @@ -34,6 +36,7 @@ def setUp(self) -> None: self.embedding_to_head_ratio = 32 self.embed_dim = self.num_heads[0] * self.embedding_to_head_ratio self.drop_rate = random.random() + self.drop_path_rate = random.random() self.input = tf.random.uniform( [self.batch_size, self.img_size, self.img_size, self.img_channels], @@ -41,13 +44,13 @@ def setUp(self) -> None: ) self.model = sm.Swin( - self.input, - self.num_classes, - self.patch_size, - self.embed_dim, - self.depths, - self.num_heads, - self.drop_rate, + num_classes=self.num_classes, + patch_size=self.patch_size, + embed_dim=self.embed_dim, + depths=self.depths, + num_heads=self.num_heads, + drop_rate=self.drop_rate, + drop_path_rate=self.drop_path_rate, ) def _build_dataset(self) -> tf.data.Dataset: @@ -88,6 +91,44 @@ def test_model_output(self) -> None: 
self.assertEqual(output.shape[0], self.batch_size) self.assertEqual(output.shape[1], self.num_classes) + def test_model_variants_output(self) -> None: + variants = [sm.SwinT, sm.SwinS, sm.SwinB, sm.SwinL] + image_size = 224 + + for variant in variants: + with self.subTest(f"Variant {variant}"): + model = variant(num_classes=self.num_classes, drop_rate=self.drop_rate) + inputs = tf.random.uniform([self.batch_size, image_size, image_size, 3]) + output = model(inputs) + + self.assertEqual(output.shape[0], self.batch_size) + self.assertEqual(output.shape[1], self.num_classes) + + def test_model_custom_window_size_output(self) -> None: + depths = [2, 4, 2] + num_heads = [4, 8, 16] + patch_size = 6 + window_size = 8 + embed_dim = num_heads[0] * 32 + img_size = 384 + num_classes = random.randint(1, 10) + batch_size = 2 ** random.randint(1, 3) + + inputs = tf.random.uniform([batch_size, img_size, img_size, 3]) + model = sm.Swin( + num_classes=num_classes, + patch_size=patch_size, + window_size=window_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + ) + + output = model(inputs) + + self.assertEqual(output.shape[0], batch_size) + self.assertEqual(output.shape[1], num_classes) + def test_model_compile(self) -> None: self.model(self.input) @@ -121,11 +162,12 @@ def test_model_restore(self) -> None: output_2 = self.model(self.input) - diff = tf.abs(output_1 - output_2) - diff = diff * 0.01 # We tolerate a 1% difference - diff = tf.floor(diff) - diff = tf.cast(diff, tf.bool) - self.assertEqual(tf.reduce_any(diff), False) + self.assertEqual( + tf.reduce_all( + tf.raw_ops.ApproximateEqual(x=output_1, y=output_2, tolerance=1e-2) + ), + True, + ) # We tolerate a 1% difference def test_model_restore_config(self) -> None: output_1 = self.model(self.input) diff --git a/tests/test_modules.py b/tests/test_modules.py index 74544d0..9740e92 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,15 +1,17 @@ """Module containing tests for the modules of 
the Swin Transformer network.""" -import tensorflow as tf import random import unittest + +import tensorflow as tf + import swin.modules as sm class TestSwinLinear(unittest.TestCase): def setUp(self) -> None: self.batch_size = random.randint(1, 5) - self.img_size = 2 ** random.randint(5, 10) # 32-1024 px + self.img_size = 2 ** random.randint(3, 8) # 8-256 self.input = tf.random.uniform( [self.batch_size, self.img_size, self.img_size, 3], dtype=tf.float32 @@ -56,22 +58,13 @@ def setUp(self) -> None: self.wrong_input_shape = tf.random.uniform( [self.batch_size, self.img_size, 2 * self.img_size, 3], dtype=tf.float32 ) - self.wrong_input_dtype = tf.image.convert_image_dtype( - tf.random.uniform( - [self.batch_size, self.img_size, self.img_size, 3], - maxval=255, - dtype=tf.float32, - ), - tf.uint8, - ) - self.wrong_input_channels = self.wrong_input_shape = tf.random.uniform( + self.wrong_input_channels = tf.random.uniform( [self.batch_size, self.img_size, self.img_size, 1], dtype=tf.float32 ) self.wrong_inputs = [ self.wrong_input_shape, self.wrong_input_channels, - self.wrong_input_dtype, ] self.embed_dim = 2 ** random.randint(5, 8) # 32-256 @@ -87,12 +80,12 @@ def test_output(self) -> None: self.assertEqual(shape[0], self.batch_size) self.assertEqual(shape[-1], self.embed_dim) - self.assertEqual(output.dtype, tf.float32) + self.assertEqual(output.dtype, self.input.dtype) def test_wrong_input(self) -> None: for input_data in self.wrong_inputs: - with self.subTest(input_data): - self.assertRaises(Exception, self.layer, input_data) + with self.subTest(f"{input_data.shape}, {input_data.dtype}"): + self.assertRaises(AssertionError, self.layer, input_data) def test_trainable_variables(self) -> None: # Build the layer @@ -117,14 +110,15 @@ def test_gradient(self) -> None: class TestSwinPatchMerging(unittest.TestCase): def setUp(self) -> None: self.batch_size = random.randint(1, 5) - self.patch_size = 2 * random.randint(7, 28) # Any even number would be fine + self.patches = 
2 * random.randint(7, 28) # Any even number would be fine self.embed_dim = 2 ** random.randint(5, 8) # 32-256 self.input = tf.random.uniform( - [self.batch_size, self.patch_size**2, self.embed_dim], dtype=tf.float32 + [self.batch_size, self.patches, self.patches, self.embed_dim], + dtype=tf.float32, ) - self.layer = sm.SwinPatchMerging(self.patch_size) + self.layer = sm.SwinPatchMerging() def test_build_odd_patch_size(self) -> None: self.assertRaises(Exception, sm.SwinPatchMerging, 7) # Any odd number is fine @@ -134,26 +128,21 @@ def test_output(self) -> None: shape = output.shape - self.assertEqual(len(shape), 3) + self.assertEqual(len(shape), 4) self.assertEqual(shape[0], self.batch_size) - self.assertEqual(shape[1], self.patch_size**2 / 4) - self.assertEqual(shape[2], self.embed_dim * 2) - self.assertEqual(output.dtype, tf.float32) + self.assertEqual(shape[1], shape[2]) + self.assertEqual(shape[1], self.patches / 2) + self.assertEqual(shape[3], self.embed_dim * 2) + self.assertEqual(output.dtype, self.input.dtype) def test_wrong_input(self) -> None: self.wrong_input_shape = tf.random.uniform( - [self.batch_size, self.patch_size * (self.patch_size * 2), self.embed_dim], + [self.batch_size, self.patches**2, self.embed_dim], dtype=tf.float32, ) - self.wrong_input_dtype = tf.random.uniform( - [self.batch_size, self.patch_size**2, self.embed_dim], - maxval=255, - dtype=tf.int32, - ) self.wrong_inputs = [ self.wrong_input_shape, - self.wrong_input_dtype, ] for input_data in self.wrong_inputs: @@ -186,6 +175,10 @@ def setUp(self) -> None: [4, 224, 224, 3], dtype=tf.float32 ) # Any shape and dtype would be ok + def test_wrong_probability(self) -> None: + self.assertRaises(AssertionError, sm.SwinDropPath, -1.2) + self.assertRaises(AssertionError, sm.SwinDropPath, 2.0) + def test_output_dtype(self) -> None: layer = sm.SwinDropPath(0.5) @@ -235,18 +228,27 @@ def setUp(self) -> None: self.layer = sm.SwinMlp(self.hidden_features, self.out_features, self.drop_p) 
self.batch_size = random.randint(1, 5) - self.input = tf.random.uniform([self.batch_size, 768, 96], dtype=tf.float32) + self.resolution = random.randint(1, 10) + self.embed_dim = random.randint(1, 100) + self.input = tf.random.uniform( + [self.batch_size, self.resolution, self.resolution, self.embed_dim], + dtype=tf.float32, + ) def test_output(self) -> None: output = self.layer(self.input) shape = output.shape - self.assertEqual(shape[0], self.batch_size) + self.assertEqual( + shape[:-1], [self.batch_size, self.resolution, self.resolution] + ) self.assertEqual(shape[-1], self.out_features) - self.assertEqual(output.dtype, tf.float32) + self.assertEqual(output.dtype, self.input.dtype) def test_wrong_input(self) -> None: - wrong_input = tf.random.uniform([self.batch_size, 224, 224, 96]) + wrong_input = tf.random.uniform( + [self.batch_size, self.resolution**2, self.embed_dim] + ) self.assertRaises(Exception, self.layer, wrong_input) @@ -274,13 +276,11 @@ class TestWindowAttention(unittest.TestCase): def setUp(self) -> None: self.window_size = 2 * random.randint( 1, 5 - ) # Could also be odd, but it simplifies the following operations + ) # Could also be odd, but it simplifies some future operations self.num_heads = random.randint(1, 5) self.proj_drop_r = random.random() - self.layer = sm.SwinWindowAttention( - self.window_size, self.num_heads, self.proj_drop_r - ) + self.layer = sm.SwinWindowAttention(self.num_heads, self.proj_drop_r) self.batch_size = random.randint(1, 5) self.embed_dim = self.num_heads * random.randint( @@ -289,12 +289,14 @@ def setUp(self) -> None: self.num_patches = self.window_size * random.randint( 2, 10 ) # Must be divisible by window_size - self.num_windows = int((self.num_patches / self.window_size)) ** 2 + self.num_windows = (self.num_patches // self.window_size) ** 2 self.input = tf.random.uniform( [ - self.batch_size * self.num_windows, - self.window_size**2, + self.batch_size, + self.num_windows, + self.window_size, + self.window_size, 
self.embed_dim, ], dtype=tf.float32, @@ -304,6 +306,15 @@ def setUp(self) -> None: self.num_patches, self.window_size, self.window_size // 2 ) + def test_build_relative_position_index_output(self) -> None: + output = self.layer.build_relative_position_index(self.window_size) + + self.assertEqual(output.dtype, tf.int32) + self.assertEqual(output.shape, [self.window_size**2, self.window_size**2]) + self.assertEqual( + output[0, self.window_size**2 - 1], 0 + ) # Top-right corner must be 0 + def test_output_no_shift(self) -> None: output = self.layer(self.input) @@ -322,25 +333,30 @@ def test_wrong_input(self) -> None: tf.random.uniform( [ self.batch_size * self.num_windows, - self.window_size**2 - 1, + self.window_size**2, self.embed_dim, ], dtype=tf.float32, ) ) # Wrong shape - wrong_inputs.append( - tf.random.uniform( - [ - self.batch_size * self.num_windows, - self.window_size**2, - self.num_heads * 2 + 1, - ], - dtype=tf.float32, + + if self.num_heads > 1: + # Incompatible emebedding dimensions. Only add this test when + # num_heads is greater than 1 or it will fail as n % 1 = 0 for any + # n. + wrong_inputs.append( + tf.random.uniform( + [ + self.batch_size * self.num_windows, + self.window_size**2, + self.num_heads - 1, + ], + dtype=tf.float32, + ) ) - ) # Incompatible emebedding dimensions for input_data in wrong_inputs: - with self.subTest(input_data): + with self.subTest(f"{input_data.shape}, {input_data.dtype}"): self.assertRaises(Exception, self.layer, input_data) def test_trainable_variables(self) -> None: @@ -365,19 +381,20 @@ def test_gradient(self) -> None: class TestSwinTransformer(unittest.TestCase): def setUp(self) -> None: - self.resolution = 2 ** random.randint(3, 6) # 8-64 + self.resolution = 2 ** random.randint(3, 6) # 8-64 patches self.num_heads = random.randint(1, 4) # Window size must be evenly divisible by resolution and > 0. 
We choose # > 1 to simplify tests where having window_size = resolution would # require a lot more code - self.window_size = int(self.resolution / (2 ** random.randint(1, 2))) - self.shift_size = int(self.window_size / 2) + self.window_size = self.resolution // ( + 2 ** random.randint(1, 2) + ) # 4-16 patches + self.shift_size = self.window_size // 2 self.mlp_ratio = 4.0 self.drop_p = random.random() self.drop_path_p = random.random() self.layer = sm.SwinTransformer( - self.resolution, self.num_heads, self.window_size, self.shift_size, @@ -386,14 +403,14 @@ def setUp(self) -> None: self.drop_path_p, ) - self.batch_size = 2 ** random.randint(0, 3) - self.patch_size = self.resolution**2 + self.batch_size = 2 ** random.randint(0, 3) # 1-8 self.embed_dim = self.num_heads * random.randint( 10, 20 ) # Any multiple of num_heads would be fine self.input = tf.random.uniform( - [self.batch_size, self.patch_size, self.embed_dim], dtype=tf.float32 + [self.batch_size, self.resolution, self.resolution, self.embed_dim], + dtype=tf.float32, ) def test_window_partition_wrong_inputs(self) -> None: @@ -404,10 +421,10 @@ def test_window_partition_wrong_inputs(self) -> None: [self.batch_size, self.resolution, self.resolution + 1, self.embed_dim], dtype=tf.float32, ) - ) # Not squared patches + ) # Non-square patches wrong_inputs.append( tf.random.uniform( - [self.batch_size, self.patch_size, self.embed_dim], + [self.batch_size, self.resolution**2, self.embed_dim], dtype=tf.float32, ) ) # Wrong rank @@ -429,10 +446,39 @@ def test_window_partition_output(self) -> None: output = sm.SwinTransformer.window_partition(input_data, self.window_size) self.assertEqual(output.dtype, input_data.dtype) + self.assertEqual(tf.rank(output), 5) self.assertEqual(output.shape[0] % self.batch_size, 0) - self.assertEqual(output.shape[1], self.window_size) - self.assertEqual(output.shape[1], output.shape[2]) - self.assertEqual(output.shape[3], self.embed_dim) + self.assertEqual(output.shape[2], 
output.shape[3]) + self.assertEqual(output.shape[2], self.window_size) + self.assertEqual(output.shape[4], self.embed_dim) + + def test_window_partition_order(self) -> None: + batch_size = 1 + resolution = 4 + embed_dim = 1 + window_size = 2 + + window_res = resolution // window_size + + input_data = tf.reshape( + tf.range(batch_size * resolution**2 * embed_dim), + [batch_size, resolution, resolution, embed_dim], + ) + output = sm.SwinTransformer.window_partition(input_data, window_size) + + for batch in range(batch_size): + for i in range(window_res): + for j in range(window_res): + win_idx = i * window_res + j + with self.subTest(f"window {win_idx}"): + out_win = output[batch, win_idx] + true_win = input_data[ + batch, + i * window_size : (i + 1) * window_size, + j * window_size : (j + 1) * window_size, + ] + + self.assertTrue(tf.reduce_all(tf.equal(out_win, true_win))) def test_window_reverse_wrong_inputs(self) -> None: wrong_inputs = [] @@ -440,26 +486,32 @@ def test_window_reverse_wrong_inputs(self) -> None: wrong_inputs.append( tf.random.uniform( [ - self.batch_size * int((self.resolution / self.window_size)) ** 2, + self.batch_size, + (self.resolution // self.window_size) ** 2, self.window_size, self.window_size + 1, self.embed_dim, ], dtype=tf.float32, ) - ) # Not squared windows + ) # Non-square windows wrong_inputs.append( tf.random.uniform( - [self.batch_size, self.window_size**2, self.embed_dim], + [ + self.batch_size, + (self.resolution // self.window_size) ** 2, + self.window_size**2, + self.embed_dim, + ], dtype=tf.float32, ) ) # Wrong rank for input_data in wrong_inputs: - with self.subTest(): + with self.subTest(f"{input_data.shape}, {input_data.dtype}"): self.assertRaises( Exception, - sm.SwinTransformer.window_partition, + sm.SwinTransformer.window_reverse, input_data, self.window_size, ) @@ -467,7 +519,8 @@ def test_window_reverse_wrong_inputs(self) -> None: def test_window_reverse_output(self) -> None: input_data = tf.random.uniform( [ - 
self.batch_size * int((self.resolution / self.window_size)) ** 2, + self.batch_size, + (self.resolution // self.window_size) ** 2, self.window_size, self.window_size, self.embed_dim, @@ -482,33 +535,6 @@ def test_window_reverse_output(self) -> None: output.shape, ) - def test_masked_fill_mask_wrong_shape(self) -> None: - x = random.randint(1, 100) - y = random.randint(1, 100) - z = random.randint(1, 10) - value = 5 - - input = tf.random.uniform([x, y], dtype=tf.float32) - mask = tf.ones([x, y, z], dtype=tf.bool) - - self.assertRaises( - Exception, sm.SwinTransformer.masked_fill, input, mask, tf.constant(value) - ) - - def test_masked_fill_output(self) -> None: - x = random.randint(1, 100) - y = random.randint(1, 100) - z = random.randint(1, 10) - value = 5 - - input = tf.random.uniform([x, y, z], dtype=tf.float32) - mask = tf.ones(input.shape, dtype=tf.float32) - output = sm.SwinTransformer.masked_fill(input, mask, tf.constant(value)) - - self.assertEqual(input.shape, output.shape) - self.assertEqual(input.dtype, output.dtype) - self.assertEqual(tf.reduce_all(output == value), True) - def test_build_attn_mask_output(self) -> None: output = sm.SwinTransformer.build_attn_mask( self.resolution, self.window_size, self.shift_size @@ -518,6 +544,9 @@ def test_build_attn_mask_output(self) -> None: self.assertEqual((self.resolution / self.window_size) ** 2, output.shape[0]) self.assertEqual(self.window_size**2, output.shape[1]) self.assertEqual(output.shape[1], output.shape[2]) + self.assertTrue( + tf.reduce_all(tf.logical_or(output == -100.0, output == 0)) + ) # No value other than 0 or -100 should be present def test_shift_size_bigger_than_window_size(self) -> None: self.assertRaises( @@ -563,20 +592,20 @@ def test_wrong_input(self) -> None: wrong_inputs.append( tf.random.uniform( - [self.batch_size, self.resolution**2 - 1, self.embed_dim], + [self.batch_size, self.resolution, self.resolution - 1, self.embed_dim], dtype=tf.float32, ) - ) # Wrong num_patches + ) # 
Non-square patches wrong_inputs.append( tf.random.uniform( - [self.batch_size, self.patch_size, self.embed_dim], + [self.batch_size, self.resolution, self.resolution, self.embed_dim], maxval=255, dtype=tf.int32, ) ) # Wrong dtype wrong_inputs.append( tf.random.uniform( - [self.batch_size, self.resolution, self.resolution, self.embed_dim], + [self.batch_size, self.resolution**2, self.embed_dim], dtype=tf.float32, ) ) # Wrong rank @@ -607,29 +636,27 @@ def test_gradient(self) -> None: class TestSwinStage(unittest.TestCase): def setUp(self) -> None: - self.resolution = 2 ** random.randint(3, 6) # 8-64 px + self.resolution = 2 ** random.randint(3, 6) # 8-64 patches self.depth = random.randint(1, 4) self.num_heads = random.randint(1, 4) - self.window_size = int( - self.resolution / (2 ** random.randint(1, 3)) + self.window_size = self.resolution // ( + 2 ** random.randint(1, 3) ) # Must be evenly divisible by resolution and > 0 self.mlp_ratio = 4.0 self.drop_p = random.random() self.drop_path_p = random.random() self.batch_size = 2 ** random.randint(0, 4) - self.num_patches = self.resolution**2 self.embed_dim = self.num_heads * random.randint( 10, 100 ) # Must be evenly divisible by num_heads self.input = tf.random.uniform( - [self.batch_size, self.num_patches, self.embed_dim], dtype=tf.float32 + [self.batch_size, self.resolution, self.resolution, self.embed_dim], + dtype=tf.float32, ) - def test_wrong_inputs(self) -> None: - layer = sm.SwinStage( - self.resolution, + self.layer_ds = sm.SwinStage( self.depth, self.num_heads, self.window_size, @@ -638,17 +665,7 @@ def test_wrong_inputs(self) -> None: self.drop_path_p, downsample=True, ) - - wrong_input = tf.random.uniform( - [self.batch_size, self.resolution, self.resolution, self.embed_dim], - dtype=tf.float32, - ) - - self.assertRaises(Exception, layer, wrong_input) - - def test_output_no_downsample(self) -> None: - layer = sm.SwinStage( - self.resolution, + self.layer_no_ds = sm.SwinStage( self.depth, self.num_heads, 
self.window_size, @@ -658,66 +675,58 @@ def test_output_no_downsample(self) -> None: downsample=False, ) - output = layer(self.input) + def test_wrong_inputs(self) -> None: + wrong_inputs = list() + + wrong_inputs.append( + tf.random.uniform( + [self.batch_size, self.resolution**2, self.embed_dim], + dtype=tf.float32, + ) + ) # Wrong rank + wrong_inputs.append( + tf.random.uniform( + [self.batch_size, self.resolution, self.resolution + 1, self.embed_dim], + dtype=tf.float32, + ) + ) # Non-square input + + for input_data in wrong_inputs: + with self.subTest(f"{input_data.shape}, {input_data.dtype}"): + self.assertRaises(AssertionError, self.layer_ds, input_data) + + def test_output_no_downsample(self) -> None: + output = self.layer_no_ds(self.input) self.assertEqual(output.dtype, self.input.dtype) self.assertEqual(output.shape, self.input.shape) def test_output_downsample(self) -> None: - layer = sm.SwinStage( - self.resolution, - self.depth, - self.num_heads, - self.window_size, - self.mlp_ratio, - self.drop_p, - self.drop_path_p, - downsample=True, - ) - - output = layer(self.input) + output = self.layer_ds(self.input) self.assertEqual(output.dtype, self.input.dtype) + self.assertEqual(len(output.shape), 4) self.assertEqual(output.shape[0], self.input.shape[0]) - self.assertEqual(output.shape[1], self.input.shape[1] / 4) - self.assertEqual(output.shape[2], self.input.shape[2] * 2) + self.assertEqual(output.shape[1], output.shape[2]) + self.assertEqual(output.shape[1], self.input.shape[1] / 2) + self.assertEqual(output.shape[3], self.input.shape[3] * 2) def test_trainable_variables(self) -> None: - layer = sm.SwinStage( - self.resolution, - self.depth, - self.num_heads, - self.window_size, - self.mlp_ratio, - self.drop_p, - self.drop_path_p, - downsample=True, - ) # Build the layer - layer(self.input) + self.layer_ds(self.input) - t_vars = layer.trainable_variables + t_vars = self.layer_ds.trainable_variables self.assertEqual(len(t_vars), 13 * self.depth + 3) def 
test_gradient(self) -> None: - layer = sm.SwinStage( - self.resolution, - self.depth, - self.num_heads, - self.window_size, - self.mlp_ratio, - self.drop_p, - self.drop_path_p, - downsample=True, - ) # Build the layer - layer(self.input) + self.layer_ds(self.input) with tf.GradientTape() as gt: - output = layer(self.input) + output = self.layer_ds(self.input) - gradients = gt.gradient(output, layer.trainable_variables) + gradients = gt.gradient(output, self.layer_ds.trainable_variables) self.assertNotIn(None, gradients)