diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py
index b80510c19..d5b69f0b4 100644
--- a/sleap/gui/commands.py
+++ b/sleap/gui/commands.py
@@ -29,13 +29,14 @@ class which inherits from `AppCommand` (or a more specialized class such as
import attr
import operator
import os
+import cv2
import re
import sys
import subprocess
from enum import Enum
from glob import glob
-from pathlib import PurePath
+from pathlib import PurePath, Path
from typing import Callable, Dict, Iterator, List, Optional, Type, Tuple
import numpy as np
@@ -1516,30 +1517,107 @@ def do_action(context: CommandContext, params: dict):
class ReplaceVideo(EditCommand):
- topics = [UpdateTopic.video]
+ topics = [UpdateTopic.video, UpdateTopic.frame]
@staticmethod
- def do_action(context: CommandContext, params: dict):
- new_paths = params["new_video_paths"]
+ def do_action(context: CommandContext, params: dict) -> bool:
+
+ import_list = params["import_list"]
+
+ for import_item, video in import_list:
+ import_params = import_item["params"]
+
+ # TODO: Will need to create a new backend if import has different extension.
+ if (
+ Path(video.backend.filename).suffix
+ != Path(import_params["filename"]).suffix
+ ):
+ raise TypeError(
+ "Importing videos with different extensions is not supported."
+ )
+ video.backend.reset(**import_params)
+
+ # Remove frames in video past last frame index
+ last_vid_frame = video.last_frame_idx
+ lfs: List[LabeledFrame] = list(context.labels.get(video))
+ if lfs is not None:
+ lfs = [lf for lf in lfs if lf.frame_idx > last_vid_frame]
+ context.labels.remove_frames(lfs)
- for video, new_path in zip(context.labels.videos, new_paths):
- if new_path != video.backend.filename:
- video.backend.filename = new_path
- video.backend.reset()
+ # Update seekbar and video length through callbacks
+ context.state.emit("video")
@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
"""Shows gui for replacing videos in project."""
- paths = [video.backend.filename for video in context.labels.videos]
- okay = MissingFilesDialog(filenames=paths, replace=True).exec_()
+ def _get_truncation_message(truncation_messages, path, video):
+ reader = cv2.VideoCapture(path)
+ last_vid_frame = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
+ lfs: List[LabeledFrame] = list(context.labels.get(video))
+ if lfs is not None:
+ lfs.sort(key=lambda lf: lf.frame_idx)
+ last_lf_frame = lfs[-1].frame_idx
+ lfs = [lf for lf in lfs if lf.frame_idx > last_vid_frame]
+
+ # Message to warn users that labels will be removed if proceed
+ if last_lf_frame > last_vid_frame:
+ message = (
+ "
Warning: Replacing this video will "
+ f"remove {len(lfs)} labeled frames.
"
+ f"Current video: {Path(video.filename).name}"
+ f" (last label at frame {last_lf_frame})
"
+ f"Replacement video: {Path(path).name}"
+ f" ({last_vid_frame} frames)
"
+ )
+ # Assumes that a project won't import the same video multiple times
+ truncation_messages[path] = message
+
+ return truncation_messages
+
+ # Warn user: newly added labels will be discarded if project is not saved
+ if not context.state["filename"] or context.state["has_changes"]:
+ QMessageBox(
+ text=("You have unsaved changes. Please save before replacing videos.")
+ ).exec_()
+ return False
+ # Select the videos we want to swap
+ old_paths = [video.backend.filename for video in context.labels.videos]
+ paths = list(old_paths)
+ okay = MissingFilesDialog(filenames=paths, replace=True).exec_()
if not okay:
return False
- params["new_video_paths"] = paths
+ # Only return an import list for videos we swap
+ new_paths = [
+ (path, video_idx)
+ for video_idx, (path, old_path) in enumerate(zip(paths, old_paths))
+ if path != old_path
+ ]
- return True
+ new_paths = []
+ old_videos = dict()
+ all_videos = context.labels.videos
+ truncation_messages = dict()
+ for video_idx, (path, old_path) in enumerate(zip(paths, old_paths)):
+ if path != old_path:
+ new_paths.append(path)
+ old_videos[path] = all_videos[video_idx]
+ truncation_messages = _get_truncation_message(
+ truncation_messages, path, video=all_videos[video_idx]
+ )
+
+ import_list = ImportVideos().ask(
+ filenames=new_paths, messages=truncation_messages
+ )
+ # Remove videos that no longer correlate to filenames.
+ old_videos_to_replace = [
+ old_videos[imp["params"]["filename"]] for imp in import_list
+ ]
+ params["import_list"] = zip(import_list, old_videos_to_replace)
+
+ return len(import_list) > 0
class RemoveVideo(EditCommand):
diff --git a/sleap/gui/dialogs/importvideos.py b/sleap/gui/dialogs/importvideos.py
index 17d899002..127b46e5a 100644
--- a/sleap/gui/dialogs/importvideos.py
+++ b/sleap/gui/dialogs/importvideos.py
@@ -47,7 +47,11 @@ class ImportVideos:
def __init__(self):
self.result = []
- def ask(self, filenames: Optional[List[str]] = None):
+ def ask(
+ self,
+ filenames: Optional[List[str]] = None,
+ messages: Optional[Dict[str, str]] = None,
+ ):
"""Runs the import UI.
1. Show file selection dialog.
@@ -59,6 +63,8 @@ def ask(self, filenames: Optional[List[str]] = None):
Returns:
List with dict of the parameters for each file to import.
"""
+ messages = dict() if messages is None else messages
+
if filenames is None:
filenames, filter = FileDialog.openMultiple(
None,
@@ -66,10 +72,12 @@ def ask(self, filenames: Optional[List[str]] = None):
".", # initial path
"Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)",
)
+
if len(filenames) > 0:
- importer = ImportParamDialog(filenames)
+ importer = ImportParamDialog(filenames, messages)
importer.accepted.connect(lambda: importer.get_data(self.result))
importer.exec_()
+
return self.result
@classmethod
@@ -91,9 +99,12 @@ class ImportParamDialog(QDialog):
filenames (list): List of files we want to import.
"""
- def __init__(self, filenames: List[str], *args, **kwargs):
+ def __init__(
+ self, filenames: List[str], messages: Dict[str, str] = None, *args, **kwargs
+ ):
super(ImportParamDialog, self).__init__(*args, **kwargs)
+ messages = dict() if messages is None else messages
self.import_widgets = []
self.setWindowTitle("Video Import Options")
@@ -135,7 +146,6 @@ def __init__(self, filenames: List[str], *args, **kwargs):
outer_layout = QVBoxLayout()
scroll_widget = QScrollArea()
- # scroll_widget.setWidgetResizable(False)
scroll_widget.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
scroll_widget.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
@@ -152,7 +162,10 @@ def __init__(self, filenames: List[str], *args, **kwargs):
this_type = import_type
break
if this_type is not None:
- import_item_widget = ImportItemWidget(file_name, this_type)
+ message = messages[file_name] if file_name in messages else ""
+ import_item_widget = ImportItemWidget(
+ file_name, this_type, message=message
+ )
self.import_widgets.append(import_item_widget)
scroll_layout.addWidget(import_item_widget)
else:
@@ -197,6 +210,7 @@ def __init__(self, filenames: List[str], *args, **kwargs):
button_layout.addWidget(import_button)
outer_layout.addLayout(button_layout)
+ self.adjustSize()
self.setLayout(outer_layout)
@@ -220,14 +234,6 @@ def get_data(self, import_result=None):
import_result.append(import_item.get_data())
return import_result
- def boundingRect(self) -> QRectF:
- """Method required by Qt."""
- return QRectF()
-
- def paint(self, painter, option, widget=None):
- """Method required by Qt."""
- pass
-
def set_all_grayscale(self):
for import_item in self.import_widgets:
widget_elements = import_item.options_widget.widget_elements
@@ -269,7 +275,14 @@ class ImportItemWidget(QFrame):
import_type (dict): Data about user-selectable import parameters.
"""
- def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
+ def __init__(
+ self,
+ file_path: str,
+ import_type: Dict[str, Any],
+ message: str = "",
+ *args,
+ **kwargs,
+ ):
super(ImportItemWidget, self).__init__(*args, **kwargs)
self.file_path = file_path
@@ -287,6 +300,9 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
self.options_widget = ImportParamWidget(
parent=self, file_path=self.file_path, import_type=self.import_type
)
+
+ self.message_widget = MessageWidget(parent=self, message=message)
+
self.preview_widget = VideoPreviewWidget(parent=self)
self.preview_widget.setFixedSize(200, 200)
@@ -295,6 +311,7 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
)
inner_layout.addWidget(self.options_widget)
+ inner_layout.addWidget(self.message_widget)
inner_layout.addWidget(self.preview_widget)
import_item_layout.addLayout(inner_layout)
self.setLayout(import_item_layout)
@@ -379,7 +396,7 @@ class ImportParamWidget(QWidget):
changed = Signal()
- def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
+ def __init__(self, file_path: str, import_type: Dict[str, Any], *args, **kwargs):
super(ImportParamWidget, self).__init__(*args, **kwargs)
self.file_path = file_path
@@ -516,13 +533,17 @@ def _find_h5_datasets(self, data_path, data_object) -> list:
)
return options
- def boundingRect(self) -> QRectF:
- """Method required by Qt."""
- return QRectF()
- def paint(self, painter, option, widget=None):
- """Method required by Qt."""
- pass
+class MessageWidget(QWidget):
+ """Widget to show message."""
+
+ def __init__(self, message: str = str(), *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.message = QLabel(message)
+ self.message.setStyleSheet("color: red")
+ self.layout = QVBoxLayout()
+ self.layout.addWidget(self.message)
+ self.setLayout(self.layout)
class VideoPreviewWidget(QWidget):
@@ -585,34 +606,28 @@ def plot(self, idx=0):
# Display image
self.view.setImage(image)
- def boundingRect(self) -> QRectF:
- """Method required by Qt."""
- return QRectF()
-
- def paint(self, painter, option, widget=None):
- """Method required by Qt."""
- pass
+# if __name__ == "__main__":
-if __name__ == "__main__":
+# app = QApplication([])
- app = QApplication([])
+# # import_list = ImportVideos().ask()
- # import_list = ImportVideos().ask()
+# filenames = [
+# "tests/data/videos/centered_pair_small.mp4",
+# "tests/data/videos/small_robot.mp4",
+# ]
- filenames = [
- "tests/data/videos/centered_pair_small.mp4",
- "tests/data/videos/small_robot.mp4",
- ]
+# messages = {"tests/data/videos/small_robot.mp4": "Testing messages"}
- import_list = []
- importer = ImportParamDialog(filenames)
- importer.accepted.connect(lambda: importer.get_data(import_list))
- importer.exec_()
+# import_list = []
+# importer = ImportParamDialog(filenames, messages=messages)
+# importer.accepted.connect(lambda: importer.get_data(import_list))
+# importer.exec_()
- for import_item in import_list:
- vid = import_item["video_class"](**import_item["params"])
- print(
- "Imported video data: (%d, %d), %d f, %d c"
- % (vid.width, vid.height, vid.frames, vid.channels)
- )
+# for import_item in import_list:
+# vid = import_item["video_class"](**import_item["params"])
+# print(
+# "Imported video data: (%d, %d), %d f, %d c"
+# % (vid.width, vid.height, vid.frames, vid.channels)
+# )
diff --git a/sleap/gui/dialogs/missingfiles.py b/sleap/gui/dialogs/missingfiles.py
index dc8237f5c..eb78fd732 100644
--- a/sleap/gui/dialogs/missingfiles.py
+++ b/sleap/gui/dialogs/missingfiles.py
@@ -4,6 +4,7 @@
import os
+from pathlib import Path, PurePath
from typing import Callable, List
from PySide2 import QtWidgets, QtCore, QtGui
@@ -47,6 +48,7 @@ def __init__(
self.filenames = filenames
self.missing = missing
+ self.replace = replace
missing_count = sum(missing)
@@ -88,11 +90,22 @@ def locateFile(self, idx: int):
caption = f"Please locate {old_filename}..."
filters = [f"Missing file type (*{old_ext})", "Any File (*.*)"]
+ filters = [filters[0]] if self.replace else filters
new_filename, _ = FileDialog.open(
None, dir=None, caption=caption, filter=";;".join(filters)
)
- if new_filename:
+ path_new_filename = Path(new_filename)
+ paths = [str(PurePath(fn)) for fn in self.filenames]
+ if str(path_new_filename) in paths:
+ # Do not allow same video to be imported more than once.
+ QtWidgets.QMessageBox(
+ text=(
+ f"The file {path_new_filename.name} cannot be added to the "
+ "project multiple times."
+ )
+ ).exec_()
+ elif new_filename:
# Try using this change to find other missing files
self.setFilename(idx, new_filename)
diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py
index 414b68132..f88373cef 100644
--- a/tests/gui/test_commands.py
+++ b/tests/gui/test_commands.py
@@ -2,19 +2,23 @@
CommandContext,
ImportDeepLabCutFolder,
ExportAnalysisFile,
+ ReplaceVideo,
get_new_version_filename,
)
from sleap.io.dataset import Labels
from sleap.io.pathutils import fix_path_separator
from sleap.io.video import Video
from sleap.io.convert import default_analysis_filename
-from sleap.instance import Instance
+from sleap.instance import Instance, LabeledFrame
from tests.info.test_h5 import extract_meta_hdf5
from tests.io.test_video import assert_video_params
from pathlib import PurePath, Path
+from typing import List
+
import shutil
+import pytest
def test_delete_user_dialog(centered_pair_predictions):
@@ -254,3 +258,52 @@ def test_ToggleGrayscale(centered_pair_predictions: Labels):
# Toggle grayscale back to "grayscale"
context.toggleGrayscale()
assert_video_params(video=video, filename=filename, grayscale=grayscale)
+
+
+def test_ReplaceVideo(
+ centered_pair_predictions: Labels, small_robot_mp4_vid: Video, hdf5_vid: Video
+):
+ """Test functionality for ToggleGrayscale on mp4/avi video"""
+
+ def get_last_lf_in_video(labels, video):
+ lfs: List[LabeledFrame] = list(labels.get(videos[0]))
+ lfs.sort(key=lambda lf: lf.frame_idx)
+ return lfs[-1].frame_idx
+
+ def replace_video(
+ new_video: Video, videos_to_replace: List[Video], context: CommandContext
+ ):
+ # Video to be imported
+ new_video_filename = new_video.backend.filename
+
+ # Replace the video
+ import_item_list = [
+ {"params": {"filename": new_video_filename, "grayscale": True}}
+ ]
+ params = {"import_list": zip(import_item_list, videos_to_replace)}
+ ReplaceVideo.do_action(context=context, params=params)
+ return new_video_filename
+
+ # Labels and video to be replaced
+ labels = centered_pair_predictions
+ context = CommandContext.from_labels(labels)
+ videos = labels.videos
+ last_lf_frame = get_last_lf_in_video(labels, videos[0])
+
+ # Replace the video
+ new_video_filename = replace_video(small_robot_mp4_vid, videos, context)
+
+ # Ensure video backend was replaced
+ video = labels.video
+ assert len(labels.videos) == 1
+ assert video.backend.grayscale == True
+ assert video.backend.filename == new_video_filename
+
+ # Ensure labels were truncated (Original video was fully labeled)
+ new_last_lf_frame = get_last_lf_in_video(labels, video)
+ # Original video was fully labeled
+ assert new_last_lf_frame == labels.video.last_frame_idx
+
+ # Attempt to replace an mp4 with an hdf5 video
+ with pytest.raises(TypeError):
+ replace_video(hdf5_vid, labels.videos, context)