Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Refactor Video Inputs (#1117)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Jan 17, 2022
1 parent d01b5a0 commit 1bbd7d7
Show file tree
Hide file tree
Showing 5 changed files with 389 additions and 149 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for COCO annotations with non-default keypoint labels to `KeypointDetectionData.from_coco` ([#1102](https://github.com/PyTorchLightning/lightning-flash/pull/1102))

- Added support for `from_csv` and `from_data_frame` to `VideoClassificationData` ([#1117](https://github.com/PyTorchLightning/lightning-flash/pull/1117))

### Changed

- Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))
Expand Down
6 changes: 5 additions & 1 deletion docs/source/api/video.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ ______________

classification.input.VideoClassificationInput
classification.input.VideoClassificationFiftyOneInput
classification.input.VideoClassificationPathsPredictInput
classification.input.VideoClassificationFoldersInput
classification.input.VideoClassificationFilesInput
classification.input.VideoClassificationDataFrameInput
classification.input.VideoClassificationCSVInput
classification.input.VideoClassificationPathsPredictInput
classification.input.VideoClassificationDataFramePredictInput
classification.input.VideoClassificationCSVPredictInput
classification.input_transform.VideoClassificationInputTransform
122 changes: 121 additions & 1 deletion flash/video/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import pandas as pd
import torch
from torch.utils.data import Sampler

from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioVideoClassificationInput
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, requires
from flash.core.utilities.stages import RunningStage
from flash.video.classification.input import (
VideoClassificationCSVInput,
VideoClassificationCSVPredictInput,
VideoClassificationDataFrameInput,
VideoClassificationDataFramePredictInput,
VideoClassificationFiftyOneInput,
VideoClassificationFilesInput,
VideoClassificationFoldersInput,
Expand Down Expand Up @@ -136,6 +142,120 @@ def from_folders(
**data_module_kwargs,
)

@classmethod
def from_data_frame(
cls,
input_field: str,
target_fields: Optional[Union[str, Sequence[str]]] = None,
train_data_frame: Optional[pd.DataFrame] = None,
train_images_root: Optional[str] = None,
train_resolver: Optional[Callable[[str, str], str]] = None,
val_data_frame: Optional[pd.DataFrame] = None,
val_images_root: Optional[str] = None,
val_resolver: Optional[Callable[[str, str], str]] = None,
test_data_frame: Optional[pd.DataFrame] = None,
test_images_root: Optional[str] = None,
test_resolver: Optional[Callable[[str, str], str]] = None,
predict_data_frame: Optional[pd.DataFrame] = None,
predict_images_root: Optional[str] = None,
predict_resolver: Optional[Callable[[str, str], str]] = None,
train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
clip_sampler: Union[str, "ClipSampler"] = "random",
clip_duration: float = 2,
clip_sampler_kwargs: Dict[str, Any] = None,
video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
decode_audio: bool = False,
decoder: str = "pyav",
input_cls: Type[Input] = VideoClassificationDataFrameInput,
predict_input_cls: Type[Input] = VideoClassificationDataFramePredictInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "VideoClassificationData":
ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
clip_sampler=clip_sampler,
clip_duration=clip_duration,
clip_sampler_kwargs=clip_sampler_kwargs,
video_sampler=video_sampler,
decode_audio=decode_audio,
decoder=decoder,
)

train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver)
val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver)
test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver)
predict_data = (predict_data_frame, input_field, predict_images_root, predict_resolver)

return cls(
input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw),
input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
predict_input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)

@classmethod
def from_csv(
cls,
input_field: str,
target_fields: Optional[Union[str, List[str]]] = None,
train_file: Optional[PATH_TYPE] = None,
train_images_root: Optional[PATH_TYPE] = None,
train_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
val_file: Optional[PATH_TYPE] = None,
val_images_root: Optional[PATH_TYPE] = None,
val_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
test_file: Optional[str] = None,
test_images_root: Optional[str] = None,
test_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
predict_file: Optional[str] = None,
predict_images_root: Optional[str] = None,
predict_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
train_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = VideoClassificationInputTransform,
clip_sampler: Union[str, "ClipSampler"] = "random",
clip_duration: float = 2,
clip_sampler_kwargs: Dict[str, Any] = None,
video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
decode_audio: bool = False,
decoder: str = "pyav",
input_cls: Type[Input] = VideoClassificationCSVInput,
predict_input_cls: Type[Input] = VideoClassificationCSVPredictInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "VideoClassificationData":
ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
clip_sampler=clip_sampler,
clip_duration=clip_duration,
clip_sampler_kwargs=clip_sampler_kwargs,
video_sampler=video_sampler,
decode_audio=decode_audio,
decoder=decoder,
)

train_data = (train_file, input_field, target_fields, train_images_root, train_resolver)
val_data = (val_file, input_field, target_fields, val_images_root, val_resolver)
test_data = (test_file, input_field, target_fields, test_images_root, test_resolver)
predict_data = (predict_file, input_field, predict_images_root, predict_resolver)

return cls(
input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw),
input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
predict_input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)

@classmethod
@requires("fiftyone")
def from_fiftyone(
Expand Down
Loading

0 comments on commit 1bbd7d7

Please sign in to comment.