diff --git a/sleap/nn/utils.py b/sleap/nn/utils.py index 39831f991..777a9bf07 100644 --- a/sleap/nn/utils.py +++ b/sleap/nn/utils.py @@ -309,6 +309,7 @@ class VideoLoader: dataset: str = None input_format: str = None grayscale: bool = False + dummy: bool = False chunk_size: int = 32 prefetch_chunks: int = 1 frame_inds: Optional[List[int]] = None @@ -367,8 +368,12 @@ def _load_video(self, filename) -> "Video": ) def load_frames(self, frame_inds): - local_vid = self._load_video(self.video.filename) - imgs = local_vid[np.array(frame_inds).astype("int64")] + if self.dummy: + dummy_shape = (len(frame_inds), *self._shape[1:]) + imgs = np.zeros(dummy_shape, dtype="int8") + else: + local_vid = self._load_video(self.video.filename) + imgs = local_vid[np.array(frame_inds).astype("int64")] return imgs def tf_load_frames(self, frame_inds): diff --git a/tests/nn/test_utils.py b/tests/nn/test_utils.py index 8701a77a5..488ae1da7 100644 --- a/tests/nn/test_utils.py +++ b/tests/nn/test_utils.py @@ -1,4 +1,5 @@ from sleap.nn.utils import VideoLoader +import numpy as np def test_grayscale_video(): @@ -7,3 +8,11 @@ def test_grayscale_video(): vid = VideoLoader(filename="tests/data/videos/small_robot.mp4", grayscale=True) assert vid.shape[-1] == 1 + + +def test_dummy_video(): + vid = VideoLoader(filename="tests/data/videos/small_robot.mp4", dummy=True) + + x = vid.load_frames([1, 3, 5]) + assert x.shape == (3, 320, 560, 3) + assert np.all(x == 0)