diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py index 9228ccc610bd..7bd60a8a9fbb 100644 --- a/cvat/apps/engine/frame_provider.py +++ b/cvat/apps/engine/frame_provider.py @@ -1,21 +1,21 @@ -# Copyright (C) 2019 Intel Corporation +# Copyright (C) 2020 Intel Corporation # # SPDX-License-Identifier: MIT +import itertools import math -from io import BytesIO from enum import Enum -import itertools +from io import BytesIO import numpy as np from PIL import Image from cvat.apps.engine.media_extractors import VideoReader, ZipReader -from cvat.apps.engine.models import DataChoice from cvat.apps.engine.mime_types import mimetypes +from cvat.apps.engine.models import DataChoice -class FrameProvider(): +class FrameProvider: class Quality(Enum): COMPRESSED = 0 ORIGINAL = 100 @@ -25,26 +25,33 @@ class Type(Enum): PIL = 1 NUMPY_ARRAY = 2 - def __init__(self, db_data): - self._db_data = db_data - if db_data.compressed_chunk_type == DataChoice.IMAGESET: - self._compressed_chunk_reader_class = ZipReader - elif db_data.compressed_chunk_type == DataChoice.VIDEO: - self._compressed_chunk_reader_class = VideoReader - else: - raise Exception('Unsupported chunk type') + class ChunkLoader: + def __init__(self, reader_class, path_getter): + self.chunk_id = None + self.chunk_reader = None + self.reader_class = reader_class + self.get_chunk_path = path_getter - if db_data.original_chunk_type == DataChoice.IMAGESET: - self._original_chunk_reader_class = ZipReader - elif db_data.original_chunk_type == DataChoice.VIDEO: - self._original_chunk_reader_class = VideoReader - else: - raise Exception('Unsupported chunk type') + def load(self, chunk_id): + if self.chunk_id != chunk_id: + self.chunk_id = chunk_id + self.chunk_reader = self.reader_class([self.get_chunk_path(chunk_id)]) + return self.chunk_reader - self._extracted_compressed_chunk = None - self._compressed_chunk_reader = None - self._extracted_original_chunk = None - self._original_chunk_reader = None + def __init__(self, db_data): + self._db_data = db_data + self._loaders = {} + + reader_class = { + DataChoice.IMAGESET: ZipReader, + DataChoice.VIDEO: VideoReader, + } + self._loaders[self.Quality.COMPRESSED] = self.ChunkLoader( + reader_class[db_data.compressed_chunk_type], + db_data.get_compressed_chunk_path) + self._loaders[self.Quality.ORIGINAL] = self.ChunkLoader( + reader_class[db_data.original_chunk_type], + db_data.get_original_chunk_path) def __len__(self): return self._db_data.size @@ -74,77 +81,41 @@ def _av_frame_to_png_bytes(av_frame): buf.seek(0) return buf - def _get_frame(self, frame_number, chunk_path_getter, extracted_chunk, chunk_reader, reader_class): - _, chunk_number, frame_offset = self._validate_frame_number(frame_number) - chunk_path = chunk_path_getter(chunk_number) - if chunk_number != extracted_chunk: - extracted_chunk = chunk_number - chunk_reader = reader_class([chunk_path]) - - frame, frame_name, _ = next(itertools.islice(chunk_reader, frame_offset, None)) - if reader_class is VideoReader: - return (self._av_frame_to_png_bytes(frame), 'image/png') - - return (frame, mimetypes.guess_type(frame_name)) - - def _get_frames(self, chunk_path_getter, reader_class, out_type): - for chunk_idx in range(math.ceil(self._db_data.size / self._db_data.chunk_size)): - chunk_path = chunk_path_getter(chunk_idx) - chunk_reader = reader_class([chunk_path]) - for frame, _, _ in chunk_reader: - if out_type == self.Type.BUFFER: - yield self._av_frame_to_png_bytes(frame) if reader_class is VideoReader else frame - elif out_type == self.Type.PIL: - yield frame.to_image() if reader_class is VideoReader else Image.open(frame) - elif out_type == self.Type.NUMPY_ARRAY: - if reader_class is VideoReader: - image = np.array(frame.to_image()) - else: - image = np.array(Image.open(frame)) - if len(image.shape) == 3 and image.shape[2] in {3, 4}: - image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR - yield image - else: - raise Exception('unsupported output type') + def _convert_frame(self, frame, reader_class, out_type): + if out_type == self.Type.BUFFER: + return self._av_frame_to_png_bytes(frame) if reader_class is VideoReader else frame + elif out_type == self.Type.PIL: + return frame.to_image() if reader_class is VideoReader else Image.open(frame) + elif out_type == self.Type.NUMPY_ARRAY: + if reader_class is VideoReader: + image = np.array(frame.to_image()) + else: + image = np.array(Image.open(frame)) + if len(image.shape) == 3 and image.shape[2] in {3, 4}: + image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR + return image + else: + raise Exception('unsupported output type') def get_preview(self): return self._db_data.get_preview_path() def get_chunk(self, chunk_number, quality=Quality.ORIGINAL): chunk_number = self._validate_chunk_number(chunk_number) - if quality == self.Quality.ORIGINAL: - return self._db_data.get_original_chunk_path(chunk_number) - elif quality == self.Quality.COMPRESSED: - return self._db_data.get_compressed_chunk_path(chunk_number) + return self._loaders[quality].get_chunk_path(chunk_number) def get_frame(self, frame_number, quality=Quality.ORIGINAL): - if quality == self.Quality.ORIGINAL: - return self._get_frame( - frame_number=frame_number, - chunk_path_getter=self._db_data.get_original_chunk_path, - extracted_chunk=self._extracted_original_chunk, - chunk_reader=self._original_chunk_reader, - reader_class=self._original_chunk_reader_class, - ) - elif quality == self.Quality.COMPRESSED: - return self._get_frame( - frame_number=frame_number, - chunk_path_getter=self._db_data.get_compressed_chunk_path, - extracted_chunk=self._extracted_compressed_chunk, - chunk_reader=self._compressed_chunk_reader, - reader_class=self._compressed_chunk_reader_class, - ) + _, chunk_number, frame_offset = self._validate_frame_number(frame_number) + + chunk_reader = self._loaders[quality].load(chunk_number) + + frame, frame_name, _ = next(itertools.islice(chunk_reader, frame_offset, None)) + if self._loaders[quality].reader_class is VideoReader: + return (self._av_frame_to_png_bytes(frame), 'image/png') + return (frame, mimetypes.guess_type(frame_name)) def get_frames(self, quality=Quality.ORIGINAL, out_type=Type.BUFFER): - if quality == self.Quality.ORIGINAL: - return self._get_frames( - chunk_path_getter=self._db_data.get_original_chunk_path, - reader_class=self._original_chunk_reader_class, - out_type=out_type, - ) - elif quality == self.Quality.COMPRESSED: - return self._get_frames( - chunk_path_getter=self._db_data.get_compressed_chunk_path, - reader_class=self._compressed_chunk_reader_class, - out_type=out_type, - ) + loader = self._loaders[quality] + for chunk_idx in range(math.ceil(self._db_data.size / self._db_data.chunk_size)): + for frame, _, _ in loader.load(chunk_idx): + yield self._convert_frame(frame, loader.reader_class, out_type)