From f0148413640d3e9efa774797a84f021ba8c356ea Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 27 Apr 2022 13:17:56 +0100 Subject: [PATCH] Add revamped docs for video classification models (#5894) * Add revamped docs for video classification models * EOL --- docs/source/conf.py | 1 + docs/source/models/video_resnet.rst | 26 +++++++++++ docs/source/models_new.rst | 21 +++++++++ torchvision/models/video/resnet.py | 69 ++++++++++++++++++++--------- 4 files changed, 96 insertions(+), 21 deletions(-) create mode 100644 docs/source/models/video_resnet.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index 2e648bf959d..e54e8bfb582 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -379,6 +379,7 @@ def generate_weights_table(module, table_name, metrics): generate_weights_table( module=M.segmentation, table_name="segmentation", metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")] ) +generate_weights_table(module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")]) def setup(app): diff --git a/docs/source/models/video_resnet.rst b/docs/source/models/video_resnet.rst new file mode 100644 index 00000000000..a3f92b546b9 --- /dev/null +++ b/docs/source/models/video_resnet.rst @@ -0,0 +1,26 @@ +Video ResNet +============ + +.. currentmodule:: torchvision.models.video + +The VideoResNet model is based on the `A Closer Look at Spatiotemporal +Convolutions for Action Recognition `__ paper. + + +Model builders +-------------- + +The following model builders can be used to instantiate a VideoResNet model, with or +without pre-trained weights. All the model builders internally rely on the +``torchvision.models.video.resnet.VideoResNet`` base class. Please refer to the `source +code +`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + r3d_18 + mc3_18 + r2plus1d_18 diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst index f3924205710..41a7b3d482f 100644 --- a/docs/source/models_new.rst +++ b/docs/source/models_new.rst @@ -101,3 +101,24 @@ Table of all available detection weights Box MAPs are reported on COCO .. include:: generated/detection_table.rst + + +Video Classification +==================== + +.. currentmodule:: torchvision.models.video + +The following video classification models are available, with or without +pre-trained weights: + +.. toctree:: + :maxdepth: 1 + + models/video_resnet + +Table of all available video classification weights +--------------------------------------------------- + +Accuracies are reported on Kinetics-400 + +.. include:: generated/video_table.rst diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 777057a088a..5b0d7d99bca 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -365,15 +365,24 @@ class R2Plus1D_18_Weights(WeightsEnum): @handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - """Construct 18 layer Resnet3D model as in - https://arxiv.org/abs/1711.11248 + """Construct 18 layer Resnet3D model. - Args: - weights (R3D_18_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. - Returns: - VideoResNet: R3D-18 network + Args: + weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.R3D_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.R3D_18_Weights + :members: """ weights = R3D_18_Weights.verify(weights) @@ -390,15 +399,24 @@ def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, * @handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - """Constructor for 18 layer Mixed Convolution network as in - https://arxiv.org/abs/1711.11248 + """Construct 18 layer Mixed Convolution network as in - Args: - weights (MC3_18_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. - Returns: - VideoResNet: MC3 Network definition + Args: + weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MC3_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MC3_18_Weights + :members: """ weights = MC3_18_Weights.verify(weights) @@ -415,15 +433,24 @@ def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, * @handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - """Constructor for the 18 layer deep R(2+1)D network as in - https://arxiv.org/abs/1711.11248 + """Construct 18 layer deep R(2+1)D network as in - Args: - weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model - progress (bool): If True, displays a progress bar of the download to stderr + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. - Returns: - VideoResNet: R(2+1)D-18 network + Args: + weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.R2Plus1D_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.R2Plus1D_18_Weights + :members: """ weights = R2Plus1D_18_Weights.verify(weights)