Skip to content

Commit

Permalink
Add Video Swin Transformer (#2369)
Browse files Browse the repository at this point in the history
* init video swin

* add: 3d window size computation

* add: mlp layer

* add: patch embedding layer

* add: patch merging layer

* add: window attention layer

* add: basic layer for video swin

* update: basic layer for video swin

* add: swin blocks for video swin

* create and add: video swin backbone

* rename: video swin layers to model specific

* update module import

* update module import

* set class method to private usage

* set init params for backbone

* rm redundant imports

* add video swin layer test cases

* add: videoswin backbone aliases

* add: video swin backbone presets

* add: video swin backbone presets test

* update: video swin backbone presets test

* add: video classifier task

* add: video swin classifier presets

* run formatters

* rename module name/id"

* add hard-coded normalization for include rescaling=true

* add docstring for videoswin backbone

* update metadata: backbone presets no weights

* update: backbone presets no weights test

* update video swin aliases for no weights

* add: video swin backbone presets with weights

* update: video swin aliases with weights presets

* update video swin layer test cases

* added patch merging test

* imported video swins presets to backbone presets list"

* fix: typos"

* run formatters"

* fix: linting issue

* fix: linting issue

* fix: video swin layer test cases"

* add: video swin backbone test

* rm redundant code

* disable preset test temporary

* set include rescale to true

* add video swin components to __init__

* update docstrings: video siwn layers scripts

* update copywrite status: video siwn layers test scripts

* update copywrite status: video siwn backbone scripts

* bug fixes: video swin backbone layers

* update get config of video swin backbone

* enable: video swin backbone test cases

* update: video swin backbone test cases

* update: video swin backbone preset test cases

* run formatters

* fix typos: video swin backbone test cases

* add: non implemented property for test reason

* fix: typos

* add: video classifier test

* update: video classifier test

* update: video classifier test input shape

* bug fix: mlp layer build method

* updated: swin back layer build method

* bug fix: use tf.TensorShape in compute_output_shape method

* update: video_classifier_test model.predict to model.call

* update test cases and format the code

* update docstrings and preset config

* fix jax DynamicJaxprTrace issue for

* update config of backbone aliases

* add can run in mixed precision test

* add can run on gray video

* minor fix

* specify axis in keras.ops.take to match with tf.gather

* specify include rescaling to backbone class

* remove shift size form get config of video basic layer

* add support arbitrary input shape

* minor updates to swin layers

* test method update for swin layers

* update test method to swin backbone

* remove unsed code

* bug fix in call method of patch embed layer

* fix typo in patch merging layer

* minor fix

* fix keras.ops.cond issue with jax

* no test for jit compile in torch

* reduce tensor size for forward test

* minor fix

* remove kcv export decorator

* update keras.Layer import

* remove unused layer import

* replace keras.layers instead of layers

* update keras.Layer to keras.layers.Layer for keras2

* add window_size param to aliases

* move vide swin layer to model specific directory

* minor fix
  • Loading branch information
innat authored Apr 5, 2024
1 parent c123d51 commit bfeba12
Show file tree
Hide file tree
Showing 13 changed files with 2,171 additions and 0 deletions.
13 changes: 13 additions & 0 deletions keras_cv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,24 @@
ResNetV2Backbone,
)
from keras_cv.models.backbones.vgg16.vgg16_backbone import VGG16Backbone
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
VideoSwinBBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
VideoSwinSBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_aliases import (
VideoSwinTBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_backbone import (
VideoSwinBackbone,
)
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetHBackbone
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone
from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone
from keras_cv.models.classification.image_classifier import ImageClassifier
from keras_cv.models.classification.video_classifier import VideoClassifier
from keras_cv.models.feature_extractor.clip import CLIP
from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import (
Expand Down
3 changes: 3 additions & 0 deletions keras_cv/models/backbones/backbone_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from keras_cv.models.backbones.mobilenet_v3 import mobilenet_v3_backbone_presets
from keras_cv.models.backbones.resnet_v1 import resnet_v1_backbone_presets
from keras_cv.models.backbones.resnet_v2 import resnet_v2_backbone_presets
from keras_cv.models.backbones.video_swin import video_swin_backbone_presets
from keras_cv.models.backbones.vit_det import vit_det_backbone_presets
from keras_cv.models.object_detection.yolo_v8 import yolo_v8_backbone_presets

Expand All @@ -42,6 +43,7 @@
**efficientnet_lite_backbone_presets.backbone_presets_no_weights,
**yolo_v8_backbone_presets.backbone_presets_no_weights,
**vit_det_backbone_presets.backbone_presets_no_weights,
**video_swin_backbone_presets.backbone_presets_no_weights,
}

backbone_presets_with_weights = {
Expand All @@ -55,6 +57,7 @@
**efficientnet_lite_backbone_presets.backbone_presets_with_weights,
**yolo_v8_backbone_presets.backbone_presets_with_weights,
**vit_det_backbone_presets.backbone_presets_with_weights,
**video_swin_backbone_presets.backbone_presets_with_weights,
}

backbone_presets = {
Expand Down
13 changes: 13 additions & 0 deletions keras_cv/models/backbones/video_swin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
158 changes: 158 additions & 0 deletions keras_cv/models/backbones/video_swin/video_swin_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_cv.models.backbones.video_swin.video_swin_backbone import (
VideoSwinBackbone,
)
from keras_cv.models.backbones.video_swin.video_swin_backbone_presets import (
backbone_presets,
)
from keras_cv.utils.python_utils import classproperty

ALIAS_DOCSTRING = """VideoSwin{size}Backbone model.
Reference:
- [Video Swin Transformer](https://arxiv.org/abs/2106.13230)
- [Video Swin Transformer GitHub](https://github.com/SwinTransformer/Video-Swin-Transformer)
For transfer learning use cases, make sure to read the
[guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).
Examples:
```python
input_data = np.ones(shape=(1, 32, 224, 224, 3))
# Randomly initialized backbone
model = VideoSwin{size}Backbone()
output = model(input_data)
```
""" # noqa: E501


class VideoSwinTBackbone(VideoSwinBackbone):
def __new__(
cls,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 7, 7],
include_rescaling=True,
**kwargs,
):
kwargs.update(
{
"embed_dim": embed_dim,
"depths": depths,
"num_heads": num_heads,
"window_size": window_size,
"include_rescaling": include_rescaling,
}
)
return VideoSwinBackbone.from_preset("videoswin_tiny", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"videoswin_tiny_kinetics400": copy.deepcopy(
backbone_presets["videoswin_tiny_kinetics400"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return cls.presets


class VideoSwinSBackbone(VideoSwinBackbone):
def __new__(
cls,
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=[8, 7, 7],
include_rescaling=True,
**kwargs,
):
kwargs.update(
{
"embed_dim": embed_dim,
"depths": depths,
"num_heads": num_heads,
"window_size": window_size,
"include_rescaling": include_rescaling,
}
)
return VideoSwinBackbone.from_preset("videoswin_small", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"videoswin_small_kinetics400": copy.deepcopy(
backbone_presets["videoswin_small_kinetics400"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return cls.presets


class VideoSwinBBackbone(VideoSwinBackbone):
def __new__(
cls,
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=[8, 7, 7],
include_rescaling=True,
**kwargs,
):
kwargs.update(
{
"embed_dim": embed_dim,
"depths": depths,
"num_heads": num_heads,
"window_size": window_size,
"include_rescaling": include_rescaling,
}
)
return VideoSwinBackbone.from_preset("videoswin_base", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"videoswin_base_kinetics400": copy.deepcopy(
backbone_presets["videoswin_base_kinetics400"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return cls.presets


setattr(VideoSwinTBackbone, "__doc__", ALIAS_DOCSTRING.format(size="T"))
setattr(VideoSwinSBackbone, "__doc__", ALIAS_DOCSTRING.format(size="S"))
setattr(VideoSwinBBackbone, "__doc__", ALIAS_DOCSTRING.format(size="B"))
Loading

0 comments on commit bfeba12

Please sign in to comment.