
upgrade pytorchvideo to 0.1.2 (#604)
* add weights path

* add available weights

* remove weight path

* add tests ✅

* fix

* update

* add str pretrained

* add test ✅

* fix

* Update flash/image/segmentation/heads.py

* Update CHANGELOG.md

* upgrade pytorchvideo

* Update flash/video/classification/data.py

Co-authored-by: Jirka Borovec <[email protected]>

* add annotation

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Authored by 4 people on Jul 20, 2021
1 parent 97f6ee3 · commit 08b56dd
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions flash/video/classification/data.py
```diff
@@ -44,12 +44,12 @@
 if _PYTORCHVIDEO_AVAILABLE:
     from pytorchvideo.data.clip_sampling import ClipSampler, make_clip_sampler
     from pytorchvideo.data.encoded_video import EncodedVideo
-    from pytorchvideo.data.encoded_video_dataset import EncodedVideoDataset, labeled_encoded_video_dataset
+    from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset, LabeledVideoDataset
     from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
     from pytorchvideo.transforms import ApplyTransformToKey, UniformTemporalSubsample
     from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip
 else:
-    ClipSampler, EncodedVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None
+    ClipSampler, LabeledVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None

 _PYTORCHVIDEO_DATA = Dict[str, Union[str, torch.Tensor, int, float, List]]
```
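
The import swap is the heart of this commit: pytorchvideo 0.1.2 drops the `pytorchvideo.data.encoded_video_dataset` module in favour of `pytorchvideo.data.labeled_video_dataset`. A minimal sketch of the optional-import guard idiom used above, with a try/except standing in for flash's `_PYTORCHVIDEO_AVAILABLE` flag:

```python
# Optional-dependency guard: bind the 0.1.2 names when pytorchvideo is
# installed, otherwise fall back to None placeholders so this module can
# still be imported without the video extra.
try:
    from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset, LabeledVideoDataset

    _PYTORCHVIDEO_AVAILABLE = True
except ImportError:
    labeled_video_dataset, LabeledVideoDataset = None, None
    _PYTORCHVIDEO_AVAILABLE = False
```
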

```diff
@@ -68,7 +68,7 @@ def __init__(
         self.decode_audio = decode_audio
         self.decoder = decoder

-    def load_data(self, data: str, dataset: Optional[Any] = None) -> 'EncodedVideoDataset':
+    def load_data(self, data: str, dataset: Optional[Any] = None) -> 'LabeledVideoDataset':
         ds = self._make_encoded_video_dataset(data)
         if self.training:
             label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels}
```
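
The comprehension in `load_data` recovers class names from the `<class>/<video>` folder layout that the dataset indexes. A tiny worked example of the same expression, with made-up paths:

```python
# Hypothetical (path, label) pairs, shaped like ds._labeled_videos._paths_and_labels.
paths_and_labels = [("videos/archery/a.mp4", 0), ("videos/bowling/b.mp4", 1)]

# Same expression as in load_data: map each label to its parent-folder name.
label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in paths_and_labels}
assert label_to_class_mapping == {0: "archery", 1: "bowling"}
```
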
```diff
@@ -82,14 +82,14 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         sample[DefaultDataKeys.METADATA] = {"filepath": video_path}
         return sample

-    def _encoded_video_to_dict(self, video) -> Dict[str, Any]:
+    def _encoded_video_to_dict(self, video, annotation: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         (
             clip_start,
             clip_end,
             clip_index,
             aug_index,
             is_last_clip,
-        ) = self.clip_sampler(0.0, video.duration)
+        ) = self.clip_sampler(0.0, video.duration, annotation)

         loaded_clip = video.get_clip(clip_start, clip_end)
```
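
The new `annotation` parameter tracks a signature change in pytorchvideo 0.1.2: a `ClipSampler` is now called with `(last_clip_time, video_duration, annotation)` and returns the five clip-info fields unpacked above. A minimal sketch, where the 2-second clips and 10-second duration are illustrative values:

```python
from pytorchvideo.data.clip_sampling import make_clip_sampler

clip_sampler = make_clip_sampler("uniform", 2.0)  # non-overlapping 2 s clips

clip_start, clip_end, clip_index, aug_index, is_last_clip = clip_sampler(
    0.0,   # end time of the previously sampled clip
    10.0,  # video duration in seconds
    None,  # annotation dict -- the third argument added in 0.1.2
)
```
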

```diff
@@ -115,7 +115,7 @@ def _encoded_video_to_dict(self, video) -> Dict[str, Any]:
             } if audio_samples is not None else {}),
         }

-    def _make_encoded_video_dataset(self, data) -> 'EncodedVideoDataset':
+    def _make_encoded_video_dataset(self, data) -> 'LabeledVideoDataset':
         raise NotImplementedError("Subclass must implement _make_encoded_video_dataset()")
```

```diff
@@ -139,8 +139,8 @@ def __init__(
             extensions=("mp4", "avi"),
         )

-    def _make_encoded_video_dataset(self, data) -> 'EncodedVideoDataset':
-        ds: EncodedVideoDataset = labeled_encoded_video_dataset(
+    def _make_encoded_video_dataset(self, data) -> 'LabeledVideoDataset':
+        ds: LabeledVideoDataset = labeled_video_dataset(
             pathlib.Path(data),
             self.clip_sampler,
             video_sampler=self.video_sampler,
```
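
For reference, the renamed `labeled_video_dataset` helper keeps the old behaviour of indexing a directory tree of labeled videos and returning a `LabeledVideoDataset`. A hedged sketch against an assumed `path/to/videos` layout with one sub-folder per class:

```python
from pytorchvideo.data.clip_sampling import make_clip_sampler
from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset

# "path/to/videos" is hypothetical: <class_name>/<video>.mp4 underneath.
dataset = labeled_video_dataset(
    "path/to/videos",
    make_clip_sampler("random", 2.0),
    decode_audio=False,
)

# LabeledVideoDataset is iterable; in 0.1.2 each sample dict should
# include "video" and "label" entries, among others.
sample = next(iter(dataset))
```
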
```diff
@@ -178,7 +178,7 @@ def __init__(
     def label_cls(self):
         return fol.Classification

-    def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDataset':
+    def _make_encoded_video_dataset(self, data: SampleCollection) -> 'LabeledVideoDataset':
         classes = self._get_classes(data)
         label_to_class_mapping = dict(enumerate(classes))
         class_to_label_mapping = {c: lab for lab, c in label_to_class_mapping.items()}
```

```diff
@@ -188,7 +188,7 @@ def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDataset':
         targets = [class_to_label_mapping[lab] for lab in labels]
         labeled_video_paths = LabeledVideoPaths(list(zip(filepaths, targets)))

-        ds: EncodedVideoDataset = EncodedVideoDataset(
+        ds: LabeledVideoDataset = LabeledVideoDataset(
             labeled_video_paths,
             self.clip_sampler,
             video_sampler=self.video_sampler,
```
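
The FiftyOne data source builds the dataset directly from explicit `(filepath, target)` pairs rather than scanning a folder. A minimal sketch of the same 0.1.2 constructor, with made-up file names and integer targets:

```python
from pytorchvideo.data.clip_sampling import make_clip_sampler
from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths

# Hypothetical pairs, mirroring zip(filepaths, targets) above.
labeled_video_paths = LabeledVideoPaths([("a.mp4", 0), ("b.mp4", 1)])

dataset = LabeledVideoDataset(
    labeled_video_paths,
    make_clip_sampler("uniform", 2.0),
    decode_audio=False,
)
```
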
2 changes: 1 addition & 1 deletion requirements/datatype_video.txt
```diff
@@ -1,4 +1,4 @@
 torchvision
 Pillow>=7.2
 kornia>=0.5.1,<0.5.4
-pytorchvideo==0.1.0
+pytorchvideo==0.1.2
```
