Skip to content

Commit

Permalink
Add revamped docs for video classification models (#5894)
Browse files Browse the repository at this point in the history
* Add revamped docs for video classification models

* EOL
  • Loading branch information
NicolasHug authored Apr 27, 2022
1 parent 36c4635 commit f014841
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def generate_weights_table(module, table_name, metrics):
generate_weights_table(
module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")]
)
generate_weights_table(module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")])


def setup(app):
Expand Down
26 changes: 26 additions & 0 deletions docs/source/models/video_resnet.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
Video ResNet
============

.. currentmodule:: torchvision.models.video

The VideoResNet model is based on the `A Closer Look at Spatiotemporal
Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__ paper.


Model builders
--------------

The following model builders can be used to instantiate a VideoResNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.video.resnet.VideoResNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_ for
more details about this class.

.. autosummary::
:toctree: generated/
:template: function.rst

r3d_18
mc3_18
r2plus1d_18
21 changes: 21 additions & 0 deletions docs/source/models_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,24 @@ Table of all available detection weights
Box MAPs are reported on COCO

.. include:: generated/detection_table.rst


Video Classification
====================

.. currentmodule:: torchvision.models.video

The following video classification models are available, with or without
pre-trained weights:

.. toctree::
:maxdepth: 1

models/video_resnet

Table of all available video classification weights
---------------------------------------------------

Accuracies are reported on Kinetics-400

.. include:: generated/video_table.rst
69 changes: 48 additions & 21 deletions torchvision/models/video/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,15 +365,24 @@ class R2Plus1D_18_Weights(WeightsEnum):

@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Construct 18 layer Resnet3D model as in
https://arxiv.org/abs/1711.11248
"""Construct 18 layer Resnet3D model.
Args:
weights (R3D_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
Returns:
VideoResNet: R3D-18 network
Args:
weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.video.R3D_18_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.video.R3D_18_Weights
:members:
"""
weights = R3D_18_Weights.verify(weights)

Expand All @@ -390,15 +399,24 @@ def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, *

@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Constructor for 18 layer Mixed Convolution network as in
https://arxiv.org/abs/1711.11248
"""Construct 18 layer Mixed Convolution network as in
Args:
weights (MC3_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
Returns:
VideoResNet: MC3 Network definition
Args:
weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.video.MC3_18_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.video.MC3_18_Weights
:members:
"""
weights = MC3_18_Weights.verify(weights)

Expand All @@ -415,15 +433,24 @@ def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, *

@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Constructor for the 18 layer deep R(2+1)D network as in
https://arxiv.org/abs/1711.11248
"""Construct 18 layer deep R(2+1)D network as in
Args:
weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
Returns:
VideoResNet: R(2+1)D-18 network
Args:
weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.video.R2Plus1D_18_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
for more details about this class.
.. autoclass:: torchvision.models.video.R2Plus1D_18_Weights
:members:
"""
weights = R2Plus1D_18_Weights.verify(weights)

Expand Down

0 comments on commit f014841

Please sign in to comment.