diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py
index f038302e428..ed3dcfcbbef 100644
--- a/test/test_datasets_video_utils.py
+++ b/test/test_datasets_video_utils.py
@@ -119,6 +119,16 @@ def test_compute_clips_for_video(self):
         self.assertTrue(clips.equal(idxs))
         self.assertTrue(idxs.flatten().equal(resampled_idxs))
 
+        # case 3: frames aren't enough for a clip
+        num_frames = 32
+        orig_fps = 30
+        new_fps = 13
+        with self.assertWarns(UserWarning):
+            clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames,
+                                                             orig_fps, new_fps)
+        self.assertEqual(len(clips), 0)
+        self.assertEqual(len(idxs), 0)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index 91b858d6b91..d3e8a4d179a 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -1,5 +1,6 @@
 import bisect
 import math
+import warnings
 from fractions import Fraction
 
 from typing import List
@@ -204,6 +205,9 @@ def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
         )
         video_pts = video_pts[idxs]
         clips = unfold(video_pts, num_frames, step)
+        if not clips.numel():
+            warnings.warn("There aren't enough frames in the current video to get a clip for the given clip length and "
+                          "frames between clips. The video (and potentially others) will be skipped.")
         if isinstance(idxs, slice):
             idxs = [idxs] * len(clips)
         else: