diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index b80510c19..d5b69f0b4 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -29,13 +29,14 @@ class which inherits from `AppCommand` (or a more specialized class such as import attr import operator import os +import cv2 import re import sys import subprocess from enum import Enum from glob import glob -from pathlib import PurePath +from pathlib import PurePath, Path from typing import Callable, Dict, Iterator, List, Optional, Type, Tuple import numpy as np @@ -1516,30 +1517,107 @@ def do_action(context: CommandContext, params: dict): class ReplaceVideo(EditCommand): - topics = [UpdateTopic.video] + topics = [UpdateTopic.video, UpdateTopic.frame] @staticmethod - def do_action(context: CommandContext, params: dict): - new_paths = params["new_video_paths"] + def do_action(context: CommandContext, params: dict) -> bool: + + import_list = params["import_list"] + + for import_item, video in import_list: + import_params = import_item["params"] + + # TODO: Will need to create a new backend if import has different extension. + if ( + Path(video.backend.filename).suffix + != Path(import_params["filename"]).suffix + ): + raise TypeError( + "Importing videos with different extensions is not supported." + ) + video.backend.reset(**import_params) + + # Remove frames in video past last frame index + last_vid_frame = video.last_frame_idx + lfs: List[LabeledFrame] = list(context.labels.get(video)) + if lfs is not None: + lfs = [lf for lf in lfs if lf.frame_idx > last_vid_frame] + context.labels.remove_frames(lfs) - for video, new_path in zip(context.labels.videos, new_paths): - if new_path != video.backend.filename: - video.backend.filename = new_path - video.backend.reset() + # Update seekbar and video length through callbacks + context.state.emit("video") @staticmethod def ask(context: CommandContext, params: dict) -> bool: """Shows gui for replacing videos in project.""" - paths = [video.backend.filename for video in context.labels.videos] - okay = MissingFilesDialog(filenames=paths, replace=True).exec_() + def _get_truncation_message(truncation_messages, path, video): + reader = cv2.VideoCapture(path) + last_vid_frame = int(reader.get(cv2.CAP_PROP_FRAME_COUNT)) + lfs: List[LabeledFrame] = list(context.labels.get(video)) + if lfs is not None: + lfs.sort(key=lambda lf: lf.frame_idx) + last_lf_frame = lfs[-1].frame_idx + lfs = [lf for lf in lfs if lf.frame_idx > last_vid_frame] + + # Message to warn users that labels will be removed if proceed + if last_lf_frame > last_vid_frame: + message = ( + "

Warning: Replacing this video will " + f"remove {len(lfs)} labeled frames.

" + f"

Current video: {Path(video.filename).name}" + f" (last label at frame {last_lf_frame})
" + f"Replacement video: {Path(path).name}" + f" ({last_vid_frame} frames)

" + ) + # Assumes that a project won't import the same video multiple times + truncation_messages[path] = message + + return truncation_messages + + # Warn user: newly added labels will be discarded if project is not saved + if not context.state["filename"] or context.state["has_changes"]: + QMessageBox( + text=("You have unsaved changes. Please save before replacing videos.") + ).exec_() + return False + # Select the videos we want to swap + old_paths = [video.backend.filename for video in context.labels.videos] + paths = list(old_paths) + okay = MissingFilesDialog(filenames=paths, replace=True).exec_() if not okay: return False - params["new_video_paths"] = paths + # Only return an import list for videos we swap + new_paths = [ + (path, video_idx) + for video_idx, (path, old_path) in enumerate(zip(paths, old_paths)) + if path != old_path + ] - return True + new_paths = [] + old_videos = dict() + all_videos = context.labels.videos + truncation_messages = dict() + for video_idx, (path, old_path) in enumerate(zip(paths, old_paths)): + if path != old_path: + new_paths.append(path) + old_videos[path] = all_videos[video_idx] + truncation_messages = _get_truncation_message( + truncation_messages, path, video=all_videos[video_idx] + ) + + import_list = ImportVideos().ask( + filenames=new_paths, messages=truncation_messages + ) + # Remove videos that no longer correlate to filenames. + old_videos_to_replace = [ + old_videos[imp["params"]["filename"]] for imp in import_list + ] + params["import_list"] = zip(import_list, old_videos_to_replace) + + return len(import_list) > 0 class RemoveVideo(EditCommand): diff --git a/sleap/gui/dialogs/importvideos.py b/sleap/gui/dialogs/importvideos.py index 17d899002..127b46e5a 100644 --- a/sleap/gui/dialogs/importvideos.py +++ b/sleap/gui/dialogs/importvideos.py @@ -47,7 +47,11 @@ class ImportVideos: def __init__(self): self.result = [] - def ask(self, filenames: Optional[List[str]] = None): + def ask( + self, + filenames: Optional[List[str]] = None, + messages: Optional[Dict[str, str]] = None, + ): """Runs the import UI. 1. Show file selection dialog. @@ -59,6 +63,8 @@ def ask(self, filenames: Optional[List[str]] = None): Returns: List with dict of the parameters for each file to import. """ + messages = dict() if messages is None else messages + if filenames is None: filenames, filter = FileDialog.openMultiple( None, @@ -66,10 +72,12 @@ def ask(self, filenames: Optional[List[str]] = None): ".", # initial path "Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)", ) + if len(filenames) > 0: - importer = ImportParamDialog(filenames) + importer = ImportParamDialog(filenames, messages) importer.accepted.connect(lambda: importer.get_data(self.result)) importer.exec_() + return self.result @classmethod @@ -91,9 +99,12 @@ class ImportParamDialog(QDialog): filenames (list): List of files we want to import. """ - def __init__(self, filenames: List[str], *args, **kwargs): + def __init__( + self, filenames: List[str], messages: Dict[str, str] = None, *args, **kwargs + ): super(ImportParamDialog, self).__init__(*args, **kwargs) + messages = dict() if messages is None else messages self.import_widgets = [] self.setWindowTitle("Video Import Options") @@ -135,7 +146,6 @@ def __init__(self, filenames: List[str], *args, **kwargs): outer_layout = QVBoxLayout() scroll_widget = QScrollArea() - # scroll_widget.setWidgetResizable(False) scroll_widget.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) scroll_widget.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) @@ -152,7 +162,10 @@ def __init__(self, filenames: List[str], *args, **kwargs): this_type = import_type break if this_type is not None: - import_item_widget = ImportItemWidget(file_name, this_type) + message = messages[file_name] if file_name in messages else "" + import_item_widget = ImportItemWidget( + file_name, this_type, message=message + ) self.import_widgets.append(import_item_widget) scroll_layout.addWidget(import_item_widget) else: @@ -197,6 +210,7 @@ def __init__(self, filenames: List[str], *args, **kwargs): button_layout.addWidget(import_button) outer_layout.addLayout(button_layout) + self.adjustSize() self.setLayout(outer_layout) @@ -220,14 +234,6 @@ def get_data(self, import_result=None): import_result.append(import_item.get_data()) return import_result - def boundingRect(self) -> QRectF: - """Method required by Qt.""" - return QRectF() - - def paint(self, painter, option, widget=None): - """Method required by Qt.""" - pass - def set_all_grayscale(self): for import_item in self.import_widgets: widget_elements = import_item.options_widget.widget_elements @@ -269,7 +275,14 @@ class ImportItemWidget(QFrame): import_type (dict): Data about user-selectable import parameters. """ - def __init__(self, file_path: str, import_type: dict, *args, **kwargs): + def __init__( + self, + file_path: str, + import_type: Dict[str, Any], + message: str = "", + *args, + **kwargs, + ): super(ImportItemWidget, self).__init__(*args, **kwargs) self.file_path = file_path @@ -287,6 +300,9 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs): self.options_widget = ImportParamWidget( parent=self, file_path=self.file_path, import_type=self.import_type ) + + self.message_widget = MessageWidget(parent=self, message=message) + self.preview_widget = VideoPreviewWidget(parent=self) self.preview_widget.setFixedSize(200, 200) @@ -295,6 +311,7 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs): ) inner_layout.addWidget(self.options_widget) + inner_layout.addWidget(self.message_widget) inner_layout.addWidget(self.preview_widget) import_item_layout.addLayout(inner_layout) self.setLayout(import_item_layout) @@ -379,7 +396,7 @@ class ImportParamWidget(QWidget): changed = Signal() - def __init__(self, file_path: str, import_type: dict, *args, **kwargs): + def __init__(self, file_path: str, import_type: Dict[str, Any], *args, **kwargs): super(ImportParamWidget, self).__init__(*args, **kwargs) self.file_path = file_path @@ -516,13 +533,17 @@ def _find_h5_datasets(self, data_path, data_object) -> list: ) return options - def boundingRect(self) -> QRectF: - """Method required by Qt.""" - return QRectF() - def paint(self, painter, option, widget=None): - """Method required by Qt.""" - pass +class MessageWidget(QWidget): + """Widget to show message.""" + + def __init__(self, message: str = str(), *args, **kwargs): + super().__init__(*args, **kwargs) + self.message = QLabel(message) + self.message.setStyleSheet("color: red") + self.layout = QVBoxLayout() + self.layout.addWidget(self.message) + self.setLayout(self.layout) class VideoPreviewWidget(QWidget): @@ -585,34 +606,28 @@ def plot(self, idx=0): # Display image self.view.setImage(image) - def boundingRect(self) -> QRectF: - """Method required by Qt.""" - return QRectF() - - def paint(self, painter, option, widget=None): - """Method required by Qt.""" - pass +# if __name__ == "__main__": -if __name__ == "__main__": +# app = QApplication([]) - app = QApplication([]) +# # import_list = ImportVideos().ask() - # import_list = ImportVideos().ask() +# filenames = [ +# "tests/data/videos/centered_pair_small.mp4", +# "tests/data/videos/small_robot.mp4", +# ] - filenames = [ - "tests/data/videos/centered_pair_small.mp4", - "tests/data/videos/small_robot.mp4", - ] +# messages = {"tests/data/videos/small_robot.mp4": "Testing messages"} - import_list = [] - importer = ImportParamDialog(filenames) - importer.accepted.connect(lambda: importer.get_data(import_list)) - importer.exec_() +# import_list = [] +# importer = ImportParamDialog(filenames, messages=messages) +# importer.accepted.connect(lambda: importer.get_data(import_list)) +# importer.exec_() - for import_item in import_list: - vid = import_item["video_class"](**import_item["params"]) - print( - "Imported video data: (%d, %d), %d f, %d c" - % (vid.width, vid.height, vid.frames, vid.channels) - ) +# for import_item in import_list: +# vid = import_item["video_class"](**import_item["params"]) +# print( +# "Imported video data: (%d, %d), %d f, %d c" +# % (vid.width, vid.height, vid.frames, vid.channels) +# ) diff --git a/sleap/gui/dialogs/missingfiles.py b/sleap/gui/dialogs/missingfiles.py index dc8237f5c..eb78fd732 100644 --- a/sleap/gui/dialogs/missingfiles.py +++ b/sleap/gui/dialogs/missingfiles.py @@ -4,6 +4,7 @@ import os +from pathlib import Path, PurePath from typing import Callable, List from PySide2 import QtWidgets, QtCore, QtGui @@ -47,6 +48,7 @@ def __init__( self.filenames = filenames self.missing = missing + self.replace = replace missing_count = sum(missing) @@ -88,11 +90,22 @@ def locateFile(self, idx: int): caption = f"Please locate {old_filename}..." filters = [f"Missing file type (*{old_ext})", "Any File (*.*)"] + filters = [filters[0]] if self.replace else filters new_filename, _ = FileDialog.open( None, dir=None, caption=caption, filter=";;".join(filters) ) - if new_filename: + path_new_filename = Path(new_filename) + paths = [str(PurePath(fn)) for fn in self.filenames] + if str(path_new_filename) in paths: + # Do not allow same video to be imported more than once. + QtWidgets.QMessageBox( + text=( + f"The file {path_new_filename.name} cannot be added to the " + "project multiple times." + ) + ).exec_() + elif new_filename: # Try using this change to find other missing files self.setFilename(idx, new_filename) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 414b68132..f88373cef 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -2,19 +2,23 @@ CommandContext, ImportDeepLabCutFolder, ExportAnalysisFile, + ReplaceVideo, get_new_version_filename, ) from sleap.io.dataset import Labels from sleap.io.pathutils import fix_path_separator from sleap.io.video import Video from sleap.io.convert import default_analysis_filename -from sleap.instance import Instance +from sleap.instance import Instance, LabeledFrame from tests.info.test_h5 import extract_meta_hdf5 from tests.io.test_video import assert_video_params from pathlib import PurePath, Path +from typing import List + import shutil +import pytest def test_delete_user_dialog(centered_pair_predictions): @@ -254,3 +258,52 @@ def test_ToggleGrayscale(centered_pair_predictions: Labels): # Toggle grayscale back to "grayscale" context.toggleGrayscale() assert_video_params(video=video, filename=filename, grayscale=grayscale) + + +def test_ReplaceVideo( + centered_pair_predictions: Labels, small_robot_mp4_vid: Video, hdf5_vid: Video +): + """Test functionality for ToggleGrayscale on mp4/avi video""" + + def get_last_lf_in_video(labels, video): + lfs: List[LabeledFrame] = list(labels.get(videos[0])) + lfs.sort(key=lambda lf: lf.frame_idx) + return lfs[-1].frame_idx + + def replace_video( + new_video: Video, videos_to_replace: List[Video], context: CommandContext + ): + # Video to be imported + new_video_filename = new_video.backend.filename + + # Replace the video + import_item_list = [ + {"params": {"filename": new_video_filename, "grayscale": True}} + ] + params = {"import_list": zip(import_item_list, videos_to_replace)} + ReplaceVideo.do_action(context=context, params=params) + return new_video_filename + + # Labels and video to be replaced + labels = centered_pair_predictions + context = CommandContext.from_labels(labels) + videos = labels.videos + last_lf_frame = get_last_lf_in_video(labels, videos[0]) + + # Replace the video + new_video_filename = replace_video(small_robot_mp4_vid, videos, context) + + # Ensure video backend was replaced + video = labels.video + assert len(labels.videos) == 1 + assert video.backend.grayscale == True + assert video.backend.filename == new_video_filename + + # Ensure labels were truncated (Original video was fully labeled) + new_last_lf_frame = get_last_lf_in_video(labels, video) + # Original video was fully labeled + assert new_last_lf_frame == labels.video.last_frame_idx + + # Attempt to replace an mp4 with an hdf5 video + with pytest.raises(TypeError): + replace_video(hdf5_vid, labels.videos, context)