From 1bbd7d701b824b625a6021e3c57c0c0173860617 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 17 Jan 2022 11:55:37 +0000 Subject: [PATCH] Refactor Video Inputs (#1117) --- CHANGELOG.md | 2 + docs/source/api/video.rst | 6 +- flash/video/classification/data.py | 122 +++++++++- flash/video/classification/input.py | 291 +++++++++++++++-------- tests/video/classification/test_model.py | 117 ++++----- 5 files changed, 389 insertions(+), 149 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a9ebd7defb..6133fa51b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for COCO annotations with non-default keypoint labels to `KeypointDetectionData.from_coco` ([#1102](https://github.com/PyTorchLightning/lightning-flash/pull/1102)) +- Added support for `from_csv` and `from_data_frame` to `VideoClassificationData` ([#1117](https://github.com/PyTorchLightning/lightning-flash/pull/1117)) + ### Changed - Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075)) diff --git a/docs/source/api/video.rst b/docs/source/api/video.rst index 4624e52b43..3006ee0261 100644 --- a/docs/source/api/video.rst +++ b/docs/source/api/video.rst @@ -22,7 +22,11 @@ ______________ classification.input.VideoClassificationInput classification.input.VideoClassificationFiftyOneInput - classification.input.VideoClassificationPathsPredictInput classification.input.VideoClassificationFoldersInput classification.input.VideoClassificationFilesInput + classification.input.VideoClassificationDataFrameInput + classification.input.VideoClassificationCSVInput + classification.input.VideoClassificationPathsPredictInput + classification.input.VideoClassificationDataFramePredictInput + classification.input.VideoClassificationCSVPredictInput classification.input_transform.VideoClassificationInputTransform diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index b6256361dc..687b2cc8f4 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union +import pandas as pd import torch from torch.utils.data import Sampler @@ -20,10 +21,15 @@ from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE +from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioVideoClassificationInput from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, requires from flash.core.utilities.stages import RunningStage from flash.video.classification.input import ( + VideoClassificationCSVInput, + VideoClassificationCSVPredictInput, + VideoClassificationDataFrameInput, + VideoClassificationDataFramePredictInput, VideoClassificationFiftyOneInput, VideoClassificationFilesInput, VideoClassificationFoldersInput, @@ -136,6 +142,120 @@ def from_folders( **data_module_kwargs, ) + @classmethod + def from_data_frame( + cls, + input_field: str, + target_fields: Optional[Union[str, Sequence[str]]] = None, + train_data_frame: Optional[pd.DataFrame] = None, + train_images_root: Optional[str] = None, + train_resolver: Optional[Callable[[str, str], str]] = None, + val_data_frame: Optional[pd.DataFrame] = None, + val_images_root: Optional[str] = None, + val_resolver: Optional[Callable[[str, str], str]] = None, + test_data_frame: Optional[pd.DataFrame] = None, + test_images_root: Optional[str] = None, + test_resolver: Optional[Callable[[str, str], str]] = None, + predict_data_frame: Optional[pd.DataFrame] = None, + predict_images_root: Optional[str] = None, + predict_resolver: Optional[Callable[[str, str], str]] = None, + train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + clip_sampler: Union[str, "ClipSampler"] = "random", + clip_duration: float = 2, + clip_sampler_kwargs: Dict[str, Any] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = False, + decoder: str = "pyav", + input_cls: Type[Input] = VideoClassificationDataFrameInput, + predict_input_cls: Type[Input] = VideoClassificationDataFramePredictInput, + transform_kwargs: Optional[Dict] = None, + **data_module_kwargs: Any, + ) -> "VideoClassificationData": + ds_kw = dict( + data_pipeline_state=DataPipelineState(), + transform_kwargs=transform_kwargs, + input_transforms_registry=cls.input_transforms_registry, + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + + train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver) + val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver) + test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver) + predict_data = (predict_data_frame, input_field, predict_images_root, predict_resolver) + + return cls( + input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), + 
input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), + predict_input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), + **data_module_kwargs, + ) + + @classmethod + def from_csv( + cls, + input_field: str, + target_fields: Optional[Union[str, List[str]]] = None, + train_file: Optional[PATH_TYPE] = None, + train_images_root: Optional[PATH_TYPE] = None, + train_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, + val_file: Optional[PATH_TYPE] = None, + val_images_root: Optional[PATH_TYPE] = None, + val_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, + test_file: Optional[str] = None, + test_images_root: Optional[str] = None, + test_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, + predict_file: Optional[str] = None, + predict_images_root: Optional[str] = None, + predict_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None, + train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform, + clip_sampler: Union[str, "ClipSampler"] = "random", + clip_duration: float = 2, + clip_sampler_kwargs: Dict[str, Any] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = False, + decoder: str = "pyav", + input_cls: Type[Input] = VideoClassificationCSVInput, + predict_input_cls: Type[Input] = VideoClassificationCSVPredictInput, + transform_kwargs: Optional[Dict] = None, + **data_module_kwargs: Any, + ) -> "VideoClassificationData": + ds_kw = dict( + data_pipeline_state=DataPipelineState(), + transform_kwargs=transform_kwargs, + input_transforms_registry=cls.input_transforms_registry, + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + + train_data = (train_file, input_field, target_fields, train_images_root, train_resolver) + val_data = (val_file, input_field, target_fields, val_images_root, val_resolver) + test_data = (test_file, input_field, target_fields, test_images_root, test_resolver) + predict_data = (predict_file, input_field, predict_images_root, predict_resolver) + + return cls( + input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw), + input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), + input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), + predict_input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), + **data_module_kwargs, + ) + @classmethod @requires("fiftyone") def from_fiftyone( diff --git a/flash/video/classification/input.py b/flash/video/classification/input.py index 8835063374..268299dd0b 100644 --- a/flash/video/classification/input.py +++ b/flash/video/classification/input.py @@ -11,17 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import pathlib -from typing import Any, Dict, Iterable, List, Tuple, Type, Union +import os +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union -import numpy as np +import pandas as pd import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Sampler -from flash.core.data.io.classification_input import ClassificationState +from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState from flash.core.data.io.input import DataKeys, Input, IterableInput -from flash.core.data.utilities.paths import list_valid_files +from flash.core.data.utilities.classification import TargetMode +from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets +from flash.core.data.utilities.paths import list_valid_files, make_dataset, PATH_TYPE from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import @@ -35,15 +37,12 @@ if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.data.clip_sampling import ClipSampler, make_clip_sampler from pytorchvideo.data.encoded_video import EncodedVideo - from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset, LabeledVideoDataset + from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths else: ClipSampler, LabeledVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None -Label = Union[int, List[int]] - - def _make_clip_sampler( clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, @@ -54,18 +53,174 @@ def _make_clip_sampler( return make_clip_sampler(clip_sampler, clip_duration, **clip_sampler_kwargs) -class VideoClassificationInput(IterableInput): - def load_data(self, dataset: "LabeledVideoDataset") -> "LabeledVideoDataset": - if self.training: - label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in dataset._labeled_videos._paths_and_labels} - self.set_state(ClassificationState(label_to_class_mapping)) - self.num_classes = len(np.unique([s[1]["label"] for s in dataset._labeled_videos])) +class VideoClassificationInput(IterableInput, ClassificationInputMixin): + def load_data( + self, + files: List[PATH_TYPE], + targets: List[Any], + clip_sampler: Union[str, "ClipSampler"] = "random", + clip_duration: float = 2, + clip_sampler_kwargs: Dict[str, Any] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = False, + decoder: str = "pyav", + ) -> "LabeledVideoDataset": + dataset = LabeledVideoDataset( + LabeledVideoPaths(list(zip(files, targets))), + _make_clip_sampler(clip_sampler, clip_duration, clip_sampler_kwargs), + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + if not self.predicting: + self.load_target_metadata([sample[1] for sample in dataset._labeled_videos._paths_and_labels]) return dataset def load_sample(self, sample): + sample["label"] = self.format_target(sample["label"]) return sample +class VideoClassificationFoldersInput(VideoClassificationInput): + def load_data( + self, + path: str, + clip_sampler: Union[str, "ClipSampler"] = "random", + clip_duration: float = 2, + clip_sampler_kwargs: Dict[str, Any] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = False, + decoder: str = "pyav", + ) -> "LabeledVideoDataset": + return super().load_data( 
+ *make_dataset(path, extensions=("mp4", "avi")), + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + + +class VideoClassificationFilesInput(VideoClassificationInput): + def load_data( + self, + paths: List[str], + targets: List[Any], + clip_sampler: Union[str, "ClipSampler"] = "random", + clip_duration: float = 2, + clip_sampler_kwargs: Dict[str, Any] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = False, + decoder: str = "pyav", + ) -> "LabeledVideoDataset": + return super().load_data( + paths, + targets, + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + + +class VideoClassificationDataFrameInput(VideoClassificationInput): + def load_data( + self, + data_frame: pd.DataFrame, + input_key: str, + target_keys: Union[str, List[str]], + root: Optional[PATH_TYPE] = None, + resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, + clip_sampler: Union[str, "ClipSampler"] = "random", + clip_duration: float = 2, + clip_sampler_kwargs: Dict[str, Any] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = False, + decoder: str = "pyav", + ) -> "LabeledVideoDataset": + result = super().load_data( + resolve_files(data_frame, input_key, root, resolver), + resolve_targets(data_frame, target_keys), + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + + # If we had binary multi-class targets then we also know the labels (column names) + if self.training and self.target_mode is TargetMode.MULTI_BINARY and isinstance(target_keys, List): + classification_state = self.get_state(ClassificationState) + self.set_state(ClassificationState(target_keys, classification_state.num_classes)) + + return result + + +class VideoClassificationCSVInput(VideoClassificationDataFrameInput): + def load_data( + self, + csv_file: PATH_TYPE, + input_key: str, + target_keys: Optional[Union[str, List[str]]] = None, + root: Optional[PATH_TYPE] = None, + resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, + clip_sampler: Union[str, "ClipSampler"] = "random", + clip_duration: float = 2, + clip_sampler_kwargs: Dict[str, Any] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = False, + decoder: str = "pyav", + ) -> "LabeledVideoDataset": + data_frame = read_csv(csv_file) + if root is None: + root = os.path.dirname(csv_file) + return super().load_data( + data_frame, + input_key, + target_keys, + root, + resolver, + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + + +class VideoClassificationFiftyOneInput(VideoClassificationInput): + def load_data( + self, + sample_collection: SampleCollection, + clip_sampler: Union[str, "ClipSampler"] = "random", + clip_duration: float = 2, + clip_sampler_kwargs: Dict[str, Any] = None, + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = False, + decoder: str = "pyav", + label_field: str = "ground_truth", + ) -> "LabeledVideoDataset": + label_utilities = 
FiftyOneLabelUtilities(label_field, fol.Classification) + label_utilities.validate(sample_collection) + + return super().load_data( + sample_collection.values("filepath"), + sample_collection.values(label_field + ".label"), + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + + class VideoClassificationPathsPredictInput(Input): def predict_load_data( self, @@ -117,107 +272,57 @@ def predict_load_sample(self, sample: str) -> Dict[str, Any]: } -class VideoClassificationFoldersInput(VideoClassificationInput): +class VideoClassificationDataFramePredictInput(VideoClassificationPathsPredictInput): def load_data( self, - path: str, + data_frame: pd.DataFrame, + input_key: str, + root: Optional[PATH_TYPE] = None, + resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = False, decoder: str = "pyav", - ) -> "LabeledVideoDataset": - dataset = labeled_video_dataset( - pathlib.Path(path), - _make_clip_sampler(clip_sampler, clip_duration, clip_sampler_kwargs), - video_sampler=video_sampler, - decode_audio=decode_audio, - decoder=decoder, - ) - return super().load_data(dataset) - - -class VideoClassificationFilesInput(VideoClassificationInput): - def _to_multi_hot(self, label_list: List[int]) -> torch.Tensor: - v = torch.zeros(len(self.labels_set)) - for label in label_list: - v[label] = 1 - return v - - def load_data( - self, - paths: List[str], - labels: List[Union[str, List]], - clip_sampler: Union[str, "ClipSampler"] = "random", - clip_duration: float = 2, - clip_sampler_kwargs: Dict[str, Any] = None, - video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, - decode_audio: bool = False, - decoder: str = "pyav", - ) -> "LabeledVideoDataset": - self.is_multilabel = any(isinstance(label, list) for label in labels) - if self.is_multilabel: - self.labels_set = {label for label_list in labels for label in label_list} - self.label_to_id = {label: i for i, label in enumerate(sorted(self.labels_set))} - self.id_to_label = {i: label for label, i in self.label_to_id.items()} - - encoded_labels = [ - self._to_multi_hot([self.label_to_id[classname] for classname in label_list]) for label_list in labels - ] - - data = list( - zip( - paths, - encoded_labels, - ) - ) - else: - self.labels_set = set(labels) - self.label_to_id = {label: i for i, label in enumerate(sorted(self.labels_set))} - self.id_to_label = {i: label for label, i in self.label_to_id.items()} - data = list(zip(paths, [self.label_to_id[classname] for classname in labels])) - - dataset = LabeledVideoDataset( - LabeledVideoPaths(data), - _make_clip_sampler(clip_sampler, clip_duration, clip_sampler_kwargs), + ) -> Iterable[Tuple[str, Any]]: + return super().load_data( + resolve_files(data_frame, input_key, root, resolver), + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, video_sampler=video_sampler, decode_audio=decode_audio, decoder=decoder, ) - if self.training: - self.set_state(ClassificationState(self.id_to_label)) - self.num_classes = len(self.labels_set) - return dataset -class VideoClassificationFiftyOneInput(VideoClassificationInput): +class VideoClassificationCSVPredictInput(VideoClassificationDataFramePredictInput): def 
load_data( self, - sample_collection: SampleCollection, + csv_file: PATH_TYPE, + input_key: str, + root: Optional[PATH_TYPE] = None, + resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = False, decoder: str = "pyav", - label_field: str = "ground_truth", - ) -> "LabeledVideoDataset": - label_utilities = FiftyOneLabelUtilities(label_field, fol.Classification) - label_utilities.validate(sample_collection) - classes = label_utilities.get_classes(sample_collection) - label_to_class_mapping = dict(enumerate(classes)) - class_to_label_mapping = {c: lab for lab, c in label_to_class_mapping.items()} - - filepaths = sample_collection.values("filepath") - labels = sample_collection.values(label_field + ".label") - targets = [class_to_label_mapping[lab] for lab in labels] - - dataset = LabeledVideoDataset( - LabeledVideoPaths(list(zip(filepaths, targets))), - _make_clip_sampler(clip_sampler, clip_duration, clip_sampler_kwargs), + ) -> Iterable[Tuple[str, Any]]: + data_frame = read_csv(csv_file) + if root is None: + root = os.path.dirname(csv_file) + return super().load_data( + data_frame, + input_key, + root, + resolver, + clip_sampler=clip_sampler, + clip_duration=clip_duration, + clip_sampler_kwargs=clip_sampler_kwargs, video_sampler=video_sampler, decode_audio=decode_audio, decoder=decoder, ) - return super().load_data(dataset) diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index f155037dbf..519f32cb74 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -13,7 +13,6 @@ # limitations under the License. import contextlib import os -import random import re import tempfile from pathlib import Path @@ -21,6 +20,7 @@ import pytest import torch +from pandas import DataFrame from torch.utils.data import SequentialSampler import flash @@ -67,7 +67,7 @@ def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=No @contextlib.contextmanager -def mock_encoded_video_dataset_file(): +def mock_video_data_frame(): """ Creates a temporary mock encoded video dataset with 4 videos labeled from 0 - 4. 
Returns a labeled video file which points to this mock encoded video dataset, the @@ -83,20 +83,23 @@ def mock_encoded_video_dataset_file(): video_file_name_2, data_2, ): - with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f: - f.write(f"{video_file_name_1} 0\n".encode()) - f.write(f"{video_file_name_2} 1\n".encode()) - f.write(f"{video_file_name_1} 2\n".encode()) - f.write(f"{video_file_name_2} 3\n".encode()) - - label_videos = [ - (0, data_1), - (1, data_2), - (2, data_1), - (3, data_2), - ] + data_frame = DataFrame.from_dict( + { + "file": [video_file_name_1, video_file_name_2, video_file_name_1, video_file_name_2], + "target": ["cat", "dog", "cat", "dog"], + } + ) + video_duration = num_frames / fps - yield f.name, label_videos, video_duration + yield data_frame, video_duration + + +@contextlib.contextmanager +def mock_video_csv_file(tmpdir): + with mock_video_data_frame() as (data_frame, video_duration): + csv_file = os.path.join(tmpdir, "data.csv") + data_frame.to_csv(csv_file) + yield csv_file, video_duration @contextlib.contextmanager @@ -119,13 +122,13 @@ def mock_encoded_video_dataset_folder(tmpdir): @pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") -def test_video_classifier_finetune_from_folders(tmpdir): - with mock_encoded_video_dataset_file() as (mock_csv, _, total_duration): +def test_video_classifier_finetune_from_folder(tmpdir): + with mock_encoded_video_dataset_folder(tmpdir) as (mock_folder, total_duration): half_duration = total_duration / 2 - 1e-9 datamodule = VideoClassificationData.from_folders( - train_folder=mock_csv, + train_folder=mock_folder, clip_sampler="uniform", clip_duration=half_duration, video_sampler=SequentialSampler, @@ -137,10 +140,20 @@ def test_video_classifier_finetune_from_folders(tmpdir): expected_t_shape = 5 assert sample["video"].shape[1] == expected_t_shape - assert len(VideoClassifier.available_backbones()) > 5 + model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50") + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=torch.cuda.device_count()) + trainer.finetune(model, datamodule=datamodule) + - datamodule = VideoClassificationData.from_folders( - train_folder=mock_csv, +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +def test_video_classifier_finetune_from_files(tmpdir): + with mock_video_data_frame() as (mock_data_frame, total_duration): + + half_duration = total_duration / 2 - 1e-9 + + datamodule = VideoClassificationData.from_files( + train_files=mock_data_frame["file"], + train_targets=mock_data_frame["target"], clip_sampler="uniform", clip_duration=half_duration, video_sampler=SequentialSampler, @@ -148,32 +161,25 @@ def test_video_classifier_finetune_from_folders(tmpdir): batch_size=1, ) + for sample in datamodule.train_dataset.data: + expected_t_shape = 5 + assert sample["video"].shape[1] == expected_t_shape + model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50") trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule) @pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") -def test_video_classifier_finetune_from_files(tmpdir): - with mock_encoded_video_dataset_file() as (mock_csv, _, total_duration): - label_names = ["label_1", "label_2", "label_3", "label_4"] - half_duration = total_duration / 2 - 1e-9 +def 
test_video_classifier_finetune_from_data_frame(tmpdir): + with mock_video_data_frame() as (mock_data_frame, total_duration): - files = [] - labels = [] - with open(mock_csv) as fin: - for line in fin: - if not line: - break - splits = line.split() - fname = splits[0] - label = label_names[random.randint(0, len(labels))] - files.append(fname) - labels.append(label) + half_duration = total_duration / 2 - 1e-9 - datamodule = VideoClassificationData.from_files( - train_files=files, - train_targets=labels, + datamodule = VideoClassificationData.from_data_frame( + "file", + "target", + train_data_frame=mock_data_frame, clip_sampler="uniform", clip_duration=half_duration, video_sampler=SequentialSampler, @@ -185,11 +191,21 @@ def test_video_classifier_finetune_from_files(tmpdir): expected_t_shape = 5 assert sample["video"].shape[1] == expected_t_shape - assert len(VideoClassifier.available_backbones()) > 5 + model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50") + trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=torch.cuda.device_count()) + trainer.finetune(model, datamodule=datamodule) - datamodule = VideoClassificationData.from_files( - train_files=files, - train_targets=labels, + +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +def test_video_classifier_finetune_from_csv(tmpdir): + with mock_video_csv_file(tmpdir) as (mock_csv, total_duration): + + half_duration = total_duration / 2 - 1e-9 + + datamodule = VideoClassificationData.from_csv( + "file", + "target", + train_file=mock_csv, clip_sampler="uniform", clip_duration=half_duration, video_sampler=SequentialSampler, @@ -197,6 +213,10 @@ def test_video_classifier_finetune_from_files(tmpdir): batch_size=1, ) + for sample in datamodule.train_dataset.data: + expected_t_shape = 5 + assert sample["video"].shape[1] == expected_t_shape + model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50") trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule) @@ -230,17 +250,6 @@ def test_video_classifier_finetune_fiftyone(tmpdir): expected_t_shape = 5 assert sample["video"].shape[1] == expected_t_shape - assert len(VideoClassifier.available_backbones()) > 5 - - datamodule = VideoClassificationData.from_fiftyone( - train_dataset=train_dataset, - clip_sampler="uniform", - clip_duration=half_duration, - video_sampler=SequentialSampler, - decode_audio=False, - batch_size=1, - ) - model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50") trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule)
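
For reference, a minimal usage sketch of the two constructors this patch adds (`VideoClassificationData.from_data_frame` and `VideoClassificationData.from_csv`). The call signatures follow the methods and tests above; the file names, column names, and paths are hypothetical, and the backbone choice simply mirrors the tests — this is an illustration, not part of the patch.

    import pandas as pd

    import flash
    from flash.video import VideoClassificationData, VideoClassifier

    # Hypothetical annotation table: one row per clip, a path column and a label column.
    train_df = pd.DataFrame(
        {
            "file": ["clip_0.mp4", "clip_1.mp4"],
            "target": ["cat", "dog"],
        }
    )

    # New in this PR: build the DataModule directly from a DataFrame.
    # Note the root argument is named `train_images_root` in the signature above,
    # even though it points at the directory containing the videos.
    datamodule = VideoClassificationData.from_data_frame(
        "file",                      # column holding the video file paths
        "target",                    # column (or list of columns) holding the targets
        train_data_frame=train_df,
        train_images_root="data/videos",  # hypothetical directory prepended to "file"
        clip_sampler="uniform",
        clip_duration=1,
        decode_audio=False,
        batch_size=1,
    )

    # ...or from a CSV file with the same columns. When no root is given, the
    # directory containing the CSV is used to resolve relative paths.
    datamodule = VideoClassificationData.from_csv(
        "file",
        "target",
        train_file="data/train.csv",      # hypothetical CSV path
        predict_file="data/predict.csv",  # hypothetical CSV path (no target column needed)
        clip_sampler="uniform",
        clip_duration=1,
        decode_audio=False,
        batch_size=1,
    )

    model = VideoClassifier(backbone="slow_r50", num_classes=datamodule.num_classes, pretrained=False)
    trainer = flash.Trainer(fast_dev_run=True)
    trainer.finetune(model, datamodule=datamodule)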