This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pytorch video (#216)
* update

* update

* Update flash/vision/video/classification/data.py

* update

* Update flash/vision/video/classification/model.py

Co-authored-by: Kaushik B <[email protected]>

* update

* update

* typo

* update

* update

* resolve some internal bugs

* update on comments

* move files

* update

* update

* update

* filter for 3.6

* update on comments

* update

* update

* update

* clean auto dataset

* typo

* update

* update on comments

* add doc

* remove backbone section

* update

* update

* update

* update

* map to None

* update

* update

* update on comments

* update script

* update on comments

* Update docs/source/reference/video_classification.rst

* update

* update

* update

* update

* Updates

* update

* update

* update

* update

* update

* update

* update

* resolve ci

* update

* update

* updates

* update

* update

* update

* update

* update

* update

* update

Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
4 people authored Apr 30, 2021
1 parent 2b39eae commit 0263fd1
Showing 28 changed files with 1,007 additions and 89 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/ci-testing.yml
@@ -42,6 +42,15 @@ jobs:
      run: |
        python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
    - name: Filter requirements
      run: |
        # pytorchvideo requires Python 3.7+, so drop it from the
        # requirements when testing against Python 3.6
        import sys
        if sys.version_info.minor < 7:
            fname = 'requirements.txt'
            lines = [line for line in open(fname).readlines() if not line.startswith('pytorchvideo')]
            open(fname, 'w').writelines(lines)
      shell: python

      # Note: This uses an internal pip API and may not always work
      # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
      - name: Get pip cache
2 changes: 2 additions & 0 deletions .gitignore
@@ -137,6 +137,7 @@ docs/notebooks/
docs/api/
titanic.csv
.vscode
.venv
data_folder
*.pt
*.zip
@@ -149,5 +150,6 @@ imdb
xsum
coco128
wmt_en_ro
action_youtube_naudio
kinetics
movie_posters
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -28,6 +28,8 @@ Lightning Flash
   reference/tabular_classification
   reference/translation
   reference/object_detection
   reference/video_classification


.. toctree::
   :maxdepth: 1
156 changes: 156 additions & 0 deletions docs/source/reference/video_classification.rst
@@ -0,0 +1,156 @@

.. _video_classification:

####################
Video Classification
####################

********
The task
********

Typically, Video Classification refers to the task of producing a label for actions identified in a given video.

The task predicts the ‘class’ a video clip most likely belongs to, together with a degree of certainty.

A class is a label that describes what action is being performed within the video clip, such as **swimming**, **playing piano**, etc.

For example, we can train the video classifier task on video clips with human actions,
and it will learn to predict the probability that a video contains a certain human action.

Lightning Flash :class:`~flash.video.VideoClassifier` and :class:`~flash.video.VideoClassificationData`
rely on `PyTorchVideo <https://pytorchvideo.readthedocs.io/en/latest/index.html>`_ internally.

You can use any model from the `PyTorchVideo Model Zoo <https://pytorchvideo.readthedocs.io/en/latest/model_zoo.html>`_
with the :class:`~flash.video.VideoClassifier`.
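
For example, you can list the registered backbones and build a classifier on top of one of them. A minimal sketch (the ``x3d_xs`` backbone name and ``num_classes=2`` below are illustrative):

.. code-block:: python

    from flash.video import VideoClassifier

    # Backbones registered from the PyTorchVideo model zoo
    print(VideoClassifier.available_models())

    # Build a video classifier on top of one of them
    model = VideoClassifier(model="x3d_xs", num_classes=2, pretrained=False)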

------

**********
Finetuning
**********

Let's say you wanted to develop a model that could determine whether a video clip contains a human **swimming** or **playing piano**,
using the `Kinetics dataset <https://deepmind.com/research/open-source/kinetics>`_.
Once we download the data using :func:`~flash.data.download_data`, all we need are the train and validation data folders to create the :class:`~flash.video.VideoClassificationData`.

.. code-block::

    video_dataset
    ├── train
    │   ├── class_1
    │   │   ├── a.ext
    │   │   ├── b.ext
    │   │   ...
    │   └── class_n
    │       ├── c.ext
    │       ├── d.ext
    │       ...
    └── val
        ├── class_1
        │   ├── e.ext
        │   ├── f.ext
        │   ...
        └── class_n
            ├── g.ext
            ├── h.ext
            ...
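
Each class gets its own sub-folder, and :class:`~flash.video.VideoClassificationData` infers the class labels from those folder names.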

.. code-block:: python

    import torch
    from torch.utils.data import SequentialSampler

    import flash
    from flash.data.utils import download_data
    from flash.video import VideoClassificationData, VideoClassifier

    import kornia.augmentation as K
    from pytorchvideo.transforms import ApplyTransformToKey, RandomShortSideScale, UniformTemporalSubsample
    from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip

    # 1. Download a video clip dataset. Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html
    download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip")

    # 2. [Optional] Specify transforms to be used during training.
    # Flash helps you to place your transform exactly where you want.
    # Learn more at https://lightning-flash.readthedocs.io/en/latest/general/data.html#flash.data.process.Preprocess
    train_transform = {
        "post_tensor_transform": Compose([
            ApplyTransformToKey(
                key="video",
                transform=Compose([
                    UniformTemporalSubsample(8),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(244),
                    RandomHorizontalFlip(p=0.5),
                ]),
            ),
        ]),
        "per_batch_transform_on_device": Compose([
            ApplyTransformToKey(
                key="video",
                transform=K.VideoSequential(
                    K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
                    K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                    data_format="BCTHW",
                    same_on_frame=False,
                ),
            ),
        ]),
    }

    # 3. Load the data from directories.
    datamodule = VideoClassificationData.from_paths(
        train_data_path="data/kinetics/train",
        val_data_path="data/kinetics/val",
        predict_data_path="data/kinetics/predict",
        clip_sampler="uniform",
        clip_duration=2,
        video_sampler=SequentialSampler,
        decode_audio=False,
        train_transform=train_transform,
    )

    # 4. List the available models
    print(VideoClassifier.available_models())
    # out: ['efficient_x3d_s', 'efficient_x3d_xs', ..., 'slowfast_r50', 'x3d_m', 'x3d_s', 'x3d_xs']

    # 5. Build the model
    model = VideoClassifier(model="x3d_xs", num_classes=datamodule.num_classes, pretrained=False)

    # 6. Create the trainer
    trainer = flash.Trainer(fast_dev_run=True)

    # 7. Finetune the model
    trainer.finetune(model, datamodule=datamodule)

    # 8. Make a prediction on a single video file
    predictions = model.predict("data/kinetics/train/archery/-1q7jA3DXQM_000005_000015.mp4")
    print(predictions)
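
Because the example passes ``fast_dev_run=True``, the trainer only runs a single batch of training and validation as a sanity check; drop that flag (and set e.g. ``max_epochs``) for a real finetuning run.
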
------

*************
API reference
*************

.. _video_classifier:

VideoClassifier
---------------

.. autoclass:: flash.video.VideoClassifier
    :members:
    :exclude-members: forward

.. _video_classification_data:

VideoClassificationData
-----------------------

.. autoclass:: flash.video.VideoClassificationData

.. automethod:: flash.video.VideoClassificationData.from_paths
1 change: 0 additions & 1 deletion flash/core/classification.py
@@ -140,7 +140,6 @@ class Labels(Classes):
    def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False, threshold: float = 0.5):
        super().__init__(multi_label=multi_label, threshold=threshold)
        self._labels = labels
        self.set_state(ClassificationState(labels))

    def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
        labels = None
1 change: 1 addition & 0 deletions flash/core/finetuning.py
@@ -50,6 +50,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo
            train_bn: Whether to train Batch Norm layer
        """
        super().__init__()

        self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names
        self.train_bn = train_bn
12 changes: 11 additions & 1 deletion flash/core/model.py
@@ -14,7 +14,7 @@
import functools
from importlib import import_module
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import inspect
import torch
import torchmetrics
from pytorch_lightning import LightningModule
@@ -325,6 +325,9 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
            getattr(data_pipeline, '_postprocess_pipeline', None),
            getattr(data_pipeline, '_serializer', None),
        )
        self._preprocess.state_dict()
        if getattr(self._preprocess, "_ddp_params_and_buffers_to_ignore", None):
            self._ddp_params_and_buffers_to_ignore = self._preprocess._ddp_params_and_buffers_to_ignore

@property
def preprocess(self) -> Preprocess:
@@ -394,6 +397,13 @@ def available_models(cls) -> List[str]:
            return []
        return registry.available_keys()

    @classmethod
    def get_model_details(cls, key) -> List[str]:
        registry: Optional[FlashRegistry] = getattr(cls, "models", None)
        if registry is None:
            return []
        return [v for v in inspect.signature(registry.get(key)).parameters.items()]

    @classmethod
    def available_schedulers(cls) -> List[str]:
        registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None)
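
The new `get_model_details` classmethod is a small introspection helper. A hypothetical use, assuming a backbone name returned by `available_models()`:

    from flash.video import VideoClassifier

    # Each entry is a (name, inspect.Parameter) pair describing an argument
    # of the registered backbone's builder function
    for name, param in VideoClassifier.get_model_details("x3d_xs"):
        print(name, param.default)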
9 changes: 3 additions & 6 deletions flash/data/process.py
@@ -11,14 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import subprocess
from abc import ABC, ABCMeta, abstractclassmethod, abstractmethod, abstractproperty, abstractstaticmethod
from abc import ABC, abstractclassmethod, abstractmethod
from dataclasses import dataclass
from importlib import import_module
from operator import truediv
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, TYPE_CHECKING, TypeVar

import torch
from pytorch_lightning.trainer.states import RunningStage
@@ -339,6 +335,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
        preprocess_state_dict["_meta"]["class_name"] = self.__class__.__name__
        preprocess_state_dict["_meta"]["_state"] = self._state
        destination['preprocess.state_dict'] = preprocess_state_dict
        self._ddp_params_and_buffers_to_ignore = ['preprocess.state_dict']
        return super()._save_to_state_dict(destination, prefix, keep_vars)

    def _check_transforms(self, transform: Optional[Dict[str, Callable]],
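
Together with the `data_pipeline` change in `flash/core/model.py` above, this keeps the `preprocess.state_dict` entry out of distributed synchronization: `DistributedDataParallel` skips any state whose name is listed in the wrapped module's `_ddp_params_and_buffers_to_ignore` attribute. A simplified sketch of what DDP does with that attribute:

    # names listed here are excluded when DDP broadcasts and verifies
    # module state across processes
    ignored = getattr(module, "_ddp_params_and_buffers_to_ignore", [])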
4 changes: 2 additions & 2 deletions flash/text/classification/data.py
@@ -36,8 +36,8 @@ def __init__(
        max_length: int,
        target: str,
        filetype: str,
        train_file: Optional[str],
        label_to_class_mapping: Optional[Dict[str, int]],
        train_file: Optional[str] = None,
        label_to_class_mapping: Optional[Dict[str, int]] = None,
    ):
        """
        This class contains the preprocessing logic for text classification
Expand Down