Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pytorch video #216

Merged
merged 70 commits into from
Apr 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
ae190b9
update
tchaton Apr 15, 2021
a4cb3a3
update
tchaton Apr 15, 2021
cdba489
Update flash/vision/video/classification/data.py
kaushikb11 Apr 15, 2021
c6b1cd7
update
tchaton Apr 15, 2021
5f1aae9
Merge branch 'pytorch_video' of https://github.com/PyTorchLightning/l…
tchaton Apr 15, 2021
2c427f6
Update flash/vision/video/classification/model.py
tchaton Apr 15, 2021
0c4a092
update
tchaton Apr 15, 2021
db6df41
Merge branch 'pytorch_video' of https://github.com/PyTorchLightning/l…
tchaton Apr 15, 2021
19ea5f1
update
tchaton Apr 15, 2021
b21f152
typo
tchaton Apr 15, 2021
aea8214
update
tchaton Apr 15, 2021
fbc43c8
update
tchaton Apr 15, 2021
73e0191
resolve some internal bugs
tchaton Apr 16, 2021
a1ff7b6
update on comments
tchaton Apr 16, 2021
3227e77
move files
tchaton Apr 16, 2021
98b4e13
update
tchaton Apr 16, 2021
9e17b50
update
tchaton Apr 16, 2021
eb286de
update
tchaton Apr 16, 2021
b122059
filter for 3.6
tchaton Apr 16, 2021
ae8197d
update on comments
tchaton Apr 16, 2021
c4526f4
update
tchaton Apr 16, 2021
0c2f852
update
tchaton Apr 16, 2021
c949061
update
tchaton Apr 16, 2021
fa30ea5
clean auto dataset
tchaton Apr 16, 2021
2777b9e
typo
tchaton Apr 16, 2021
17bfe73
update
tchaton Apr 16, 2021
b9bae51
update on comments:
tchaton Apr 16, 2021
38c9610
add doc
tchaton Apr 16, 2021
8a04ceb
remove backbone section
tchaton Apr 16, 2021
383f939
update
tchaton Apr 16, 2021
ab21afa
update
tchaton Apr 16, 2021
11bdd62
update
tchaton Apr 16, 2021
3ac8437
update
tchaton Apr 16, 2021
5a9158b
map to None
tchaton Apr 16, 2021
8ad791b
update
tchaton Apr 16, 2021
4feef51
update
tchaton Apr 16, 2021
35bb690
update on comments
tchaton Apr 16, 2021
1b4d565
update script
tchaton Apr 16, 2021
912fce0
update on comments
tchaton Apr 16, 2021
c6919f4
Update docs/source/reference/video_classification.rst
carmocca Apr 16, 2021
aeb6fee
Merge branch 'master' into pytorch_video
tchaton Apr 18, 2021
25bea44
Merge branch 'master' into pytorch_video
tchaton Apr 19, 2021
ab4b6d4
Merge branch 'master' into pytorch_video
tchaton Apr 19, 2021
480aa18
Merge branch 'master' into pytorch_video
tchaton Apr 27, 2021
41bdc5b
update
tchaton Apr 27, 2021
3c660ef
Merge branch 'master' into pytorch_video
tchaton Apr 27, 2021
04382a5
update
tchaton Apr 27, 2021
cf3ef94
update
tchaton Apr 27, 2021
9e9b656
Merge branch 'pytorch_video' of https://github.com/PyTorchLightning/l…
tchaton Apr 27, 2021
754f43c
update
tchaton Apr 27, 2021
7a09783
Updates
ethanwharris Apr 27, 2021
6697e91
update
tchaton Apr 27, 2021
231171a
update
tchaton Apr 27, 2021
92aa151
update
tchaton Apr 27, 2021
939a251
update
tchaton Apr 27, 2021
63babc6
iupdate:
tchaton Apr 27, 2021
530367d
update
tchaton Apr 29, 2021
81733ed
update
tchaton Apr 29, 2021
fdd85a2
Merge branch 'master' into pytorch_video
tchaton Apr 29, 2021
ed043b3
resolve ci
tchaton Apr 29, 2021
1735b6f
update
tchaton Apr 29, 2021
5d80e45
update
tchaton Apr 30, 2021
f201a70
updates
tchaton Apr 30, 2021
aff7657
update
tchaton Apr 30, 2021
b18457a
update
tchaton Apr 30, 2021
78dbc3a
update
tchaton Apr 30, 2021
80f7e71
update
tchaton Apr 30, 2021
1999639
update
tchaton Apr 30, 2021
3b0bd8f
update
tchaton Apr 30, 2021
43e2bc3
update
tchaton Apr 30, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ jobs:
run: |
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"

- name: Filter requirements
run: |
import sys
if sys.version_info.minor < 7:
fname = 'requirements.txt'
lines = [line for line in open(fname).readlines() if not line.startswith('pytorchvideo')]
open(fname, 'w').writelines(lines)
shell: python

# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
- name: Get pip cache
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ docs/notebooks/
docs/api/
titanic.csv
.vscode
.venv
data_folder
*.pt
*.zip
Expand All @@ -149,5 +150,6 @@ imdb
xsum
coco128
wmt_en_ro
action_youtube_naudio
kinetics
movie_posters
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ Lightning Flash
reference/tabular_classification
reference/translation
reference/object_detection
reference/video_classification


.. toctree::
:maxdepth: 1
Expand Down
156 changes: 156 additions & 0 deletions docs/source/reference/video_classification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@

.. _video_classification:

####################
Video Classification
####################

********
The task
********

Typically, Video Classification refers to the task of producing a label for actions identified in a given video.

The task predicts which ‘class’ the video clip most likely belongs to with a degree of certainty.

A class is a label that describes what action is being performed within the video clip, such as **swimming** , **playing piano**, etc.

For example, we can train the video classifier task on video clips with human actions
and it will learn to predict the probability that a video contains a certain human action.

Lightning Flash :class:`~flash.video.VideoClassifier` and :class:`~flash.video.VideoClassificationData`
relies on `PyTorchVideo <https://pytorchvideo.readthedocs.io/en/latest/index.html>`_ internally.

You can use any models from `PyTorchVideo Model Zoo <https://pytorchvideo.readthedocs.io/en/latest/model_zoo.html>`_
with the :class:`~flash.video.VideoClassifier`.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

------

**********
Finetuning
**********

Let's say you wanted to develop a model that could determine whether a video clip contains a human **swimming** or **playing piano**,
using the `Kinetics dataset <https://deepmind.com/research/open-source/kinetics>`_.
Once we download the data using :func:`~flash.data.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.video.VideoClassificationData`.

.. code-block::

video_dataset
├── train
│ ├── class_1
│ │ ├── a.ext
│ │ ├── b.ext
│ │ ...
│ └── class_n
│ ├── c.ext
│ ├── d.ext
│ ...
└── val
├── class_1
│ ├── e.ext
│ ├── f.ext
│ ...
└── class_n
├── g.ext
├── h.ext
...


.. code-block:: python

import sys

import torch
from torch.utils.data import SequentialSampler

import flash
from flash.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier
import kornia.augmentation as K
from pytorchvideo.transforms import ApplyTransformToKey, RandomShortSideScale, UniformTemporalSubsample
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip

# 1. Download a video clip dataset. Find more dataset at https://pytorchvideo.readthedocs.io/en/latest/data.html
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip")

# 2. [Optional] Specify transforms to be used during training.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
# Flash helps you to place your transform exactly where you want.
# Learn more at https://lightning-flash.readthedocs.io/en/latest/general/data.html#flash.data.process.Preprocess
train_transform = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to make these default task transforms to keep the example code as minimal as possible?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was added on purpose, to show to the users how to play with transforms.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we were thinking of adding a transform_recipe.py file to PTV transforms package which has all the default torchhub model recipes. So in the future we can potentially change it to use that.

"post_tensor_transform": Compose([
ApplyTransformToKey(
key="video",
transform=Compose([
UniformTemporalSubsample(8),
RandomShortSideScale(min_size=256, max_size=320),
RandomCrop(244),
RandomHorizontalFlip(p=0.5),
]),
),
]),
"per_batch_transform_on_device": Compose([
ApplyTransformToKey(
key="video",
transform=K.VideoSequential(
K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
data_format="BCTHW",
same_on_frame=False
)
),
]),
}

# 3. Load the data from directories.
datamodule = VideoClassificationData.from_paths(
tchaton marked this conversation as resolved.
Show resolved Hide resolved
train_data_path="data/kinetics/train",
val_data_path="data/kinetics/val",
predict_data_path="data/kinetics/predict",
clip_sampler="uniform",
clip_duration=2,
video_sampler=SequentialSampler,
decode_audio=False,
train_transform=train_transform
)

# 4. List the available models
print(VideoClassifier.available_models())
# out: ['efficient_x3d_s', 'efficient_x3d_xs', ... ,slowfast_r50', 'x3d_m', 'x3d_s', 'x3d_xs']

# 5. Build the model
model = VideoClassifier(model="x3d_xs", num_classes=datamodule.num_classes, pretrained=False)

# 6. Train the model
trainer = flash.Trainer(fast_dev_run=True)

# 6. Finetune the model
trainer.finetune(model, datamodule=datamodule)

predictions = model.predict("data/kinetics/train/archery/-1q7jA3DXQM_000005_000015.mp4")
print(predictions)


------

*************
API reference
*************

.. _video_classifier:

VideoClassifier
---------------

.. autoclass:: flash.video.VideoClassifier
:members:
:exclude-members: forward

.. _video_classification_data:

VideoClassificationData
-----------------------

.. autoclass:: flash.video.VideoClassificationData

.. automethod:: flash.video.VideoClassificationData.from_paths
1 change: 0 additions & 1 deletion flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ class Labels(Classes):
def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False, threshold: float = 0.5):
super().__init__(multi_label=multi_label, threshold=threshold)
self._labels = labels
self.set_state(ClassificationState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
labels = None
Expand Down
1 change: 1 addition & 0 deletions flash/core/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo
train_bn: Whether to train Batch Norm layer

"""
super().__init__()

self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names
self.train_bn = train_bn
Expand Down
12 changes: 11 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import functools
from importlib import import_module
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import inspect
import torch
import torchmetrics
from pytorch_lightning import LightningModule
Expand Down Expand Up @@ -325,6 +325,9 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
getattr(data_pipeline, '_postprocess_pipeline', None),
getattr(data_pipeline, '_serializer', None),
)
self._preprocess.state_dict()
if getattr(self._preprocess, "_ddp_params_and_buffers_to_ignore", None):
self._ddp_params_and_buffers_to_ignore = self._preprocess._ddp_params_and_buffers_to_ignore

@property
def preprocess(self) -> Preprocess:
Expand Down Expand Up @@ -394,6 +397,13 @@ def available_models(cls) -> List[str]:
return []
return registry.available_keys()

@classmethod
def get_model_details(cls, key) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "models", None)
if registry is None:
return []
return [v for v in inspect.signature(registry.get(key)).parameters.items()]

@classmethod
def available_schedulers(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None)
Expand Down
9 changes: 3 additions & 6 deletions flash/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import subprocess
from abc import ABC, ABCMeta, abstractclassmethod, abstractmethod, abstractproperty, abstractstaticmethod
from abc import ABC, abstractclassmethod, abstractmethod
from dataclasses import dataclass
from importlib import import_module
from operator import truediv
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, TYPE_CHECKING, TypeVar

import torch
from pytorch_lightning.trainer.states import RunningStage
Expand Down Expand Up @@ -339,6 +335,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
preprocess_state_dict["_meta"]["class_name"] = self.__class__.__name__
preprocess_state_dict["_meta"]["_state"] = self._state
destination['preprocess.state_dict'] = preprocess_state_dict
self._ddp_params_and_buffers_to_ignore = ['preprocess.state_dict']
return super()._save_to_state_dict(destination, prefix, keep_vars)

def _check_transforms(self, transform: Optional[Dict[str, Callable]],
Expand Down
4 changes: 2 additions & 2 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(
max_length: int,
target: str,
filetype: str,
train_file: Optional[str],
label_to_class_mapping: Optional[Dict[str, int]],
train_file: Optional[str] = None,
label_to_class_mapping: Optional[Dict[str, int]] = None,
):
"""
This class contains the preprocessing logic for text classification
Expand Down
Loading