diff --git a/changelog.d/20240503_110728_maria_fix_task_creation_from_video_with_no_valid_keyframes.md b/changelog.d/20240503_110728_maria_fix_task_creation_from_video_with_no_valid_keyframes.md new file mode 100644 index 000000000000..7ff0a62d96a7 --- /dev/null +++ b/changelog.d/20240503_110728_maria_fix_task_creation_from_video_with_no_valid_keyframes.md @@ -0,0 +1,4 @@ +### Fixed + +- Task creation from a video file without keyframes allowing for random iteration + () diff --git a/utils/dataset_manifest/core.py b/utils/dataset_manifest/core.py index dc050687b390..b336c8a33dca 100644 --- a/utils/dataset_manifest/core.py +++ b/utils/dataset_manifest/core.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT from enum import Enum -from io import StringIO +from io import StringIO, BytesIO import av import json import os @@ -13,12 +13,11 @@ from contextlib import closing from PIL import Image from json.decoder import JSONDecodeError -from io import BytesIO -from .errors import InvalidManifestError, InvalidVideoFrameError +from .errors import InvalidManifestError, InvalidVideoError from .utils import SortingMethod, md5_hash, rotate_image, sort -from typing import Dict, List, Union, Optional +from typing import Dict, List, Union, Optional, Iterator, Tuple class VideoStreamReader: def __init__(self, source_path, chunk_size, force): @@ -33,7 +32,7 @@ def __init__(self, source_path, chunk_size, force): for frame in packet.decode(): # check type of first frame if not frame.pict_type.name == 'I': - raise InvalidVideoFrameError('First frame is not key frame') + raise InvalidVideoError('The first frame is not a key frame') # get video resolution if video_stream.metadata.get('rotate'): @@ -75,40 +74,66 @@ def validate_key_frame(self, container, video_stream, key_frame): return False return True - def __iter__(self): - with closing(av.open(self.source_path, mode='r')) as container: - video_stream = self._get_video_stream(container) - frame_pts, frame_dts = -1, -1 - index, key_frame_number = 0, 0 - for packet in container.demux(video_stream): + def __iter__(self) -> Iterator[Union[int, Tuple[int, int, str]]]: + """ + Iterate over video frames and yield key frames or indexes. + + Yields: + Union[Tuple[int, int, str], int]: (frame index, frame timestamp, frame MD5) or frame index. + """ + # Open containers for reading frames and checking movement on them + with ( + closing(av.open(self.source_path, mode='r')) as reading_container, + closing(av.open(self.source_path, mode='r')) as checking_container + ): + reading_v_stream = self._get_video_stream(reading_container) + checking_v_stream = self._get_video_stream(checking_container) + prev_pts: Optional[int] = None + prev_dts: Optional[int] = None + index, key_frame_count = 0, 0 + + for packet in reading_container.demux(reading_v_stream): for frame in packet.decode(): - if None not in {frame.pts, frame_pts} and frame.pts <= frame_pts: - raise InvalidVideoFrameError('Invalid pts sequences') - if None not in {frame.dts, frame_dts} and frame.dts <= frame_dts: - raise InvalidVideoFrameError('Invalid dts sequences') - frame_pts, frame_dts = frame.pts, frame.dts + # Check PTS and DTS sequences for validity + if None not in {frame.pts, prev_pts} and frame.pts <= prev_pts: + raise InvalidVideoError('Detected non-increasing PTS sequence in the video') + if None not in {frame.dts, prev_dts} and frame.dts <= prev_dts: + raise InvalidVideoError('Detected non-increasing DTS sequence in the video') + prev_pts, prev_dts = frame.pts, frame.dts if frame.key_frame: - key_frame_number += 1 - ratio = (index + 1) // key_frame_number - - if ratio >= self._upper_bound and not self._force: - raise AssertionError('Too few keyframes') - - key_frame = { - 'index': index, + key_frame_data = { 'pts': frame.pts, - 'md5': md5_hash(frame) + 'md5': md5_hash(frame), } - with closing(av.open(self.source_path, mode='r')) as checked_container: - checked_container.seek(offset=key_frame['pts'], stream=video_stream) - isValid = self.validate_key_frame(checked_container, video_stream, key_frame) - if isValid: - yield (index, key_frame['pts'], key_frame['md5']) + # Check that it is possible to seek to this key frame using frame.pts + checking_container.seek( + offset=key_frame_data['pts'], + stream=checking_v_stream, + ) + is_valid_key_frame = self.validate_key_frame( + checking_container, + checking_v_stream, + key_frame_data, + ) + + if is_valid_key_frame: + key_frame_count += 1 + yield (index, key_frame_data['pts'], key_frame_data['md5']) + else: + yield index else: yield index + index += 1 + key_frame_ratio = index // (key_frame_count or 1) + + # Check if the number of key frames meets the upper bound + if key_frame_ratio >= self._upper_bound and not self._force: + raise InvalidVideoError('The number of keyframes is not enough for smooth iteration over the video') + + # Update frames number if not already set if not self._frames_number: self._frames_number = index @@ -317,6 +342,9 @@ def __getitem__(self, number): def __len__(self): return len(self._index) + def is_empty(self) -> bool: + return not len(self) + class _ManifestManager(ABC): BASE_INFORMATION = { 'version' : 1, @@ -405,10 +433,12 @@ def manifest(self): return self._manifest def __len__(self): - if hasattr(self, '_index'): - return len(self._index) - else: - return None + return len(self._index) + + def is_empty(self) -> bool: + if self._index.is_empty(): + self._index.load() + return self._index.is_empty() def __getitem__(self, item): if isinstance(item, slice): @@ -482,6 +512,9 @@ def create(self, *, _tqdm=None): # pylint: disable=arguments-differ self.set_index() + if self.is_empty() and not self._reader._force: + raise InvalidManifestError('Empty manifest file has been created') + def partial_update(self, number, properties): pass diff --git a/utils/dataset_manifest/errors.py b/utils/dataset_manifest/errors.py index 516640bad1fc..ea2847c63308 100644 --- a/utils/dataset_manifest/errors.py +++ b/utils/dataset_manifest/errors.py @@ -7,7 +7,7 @@ class BasicError(Exception): The basic exception type for all exceptions in the library """ -class InvalidVideoFrameError(BasicError): +class InvalidVideoError(BasicError): """ Indicates an invalid video frame """