Skip to content

Commit

Permalink
reformatted linting with black
Browse files Browse the repository at this point in the history
  • Loading branch information
ericleonardis committed Dec 17, 2024
1 parent 9f71476 commit 117d5a3
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 29 deletions.
30 changes: 21 additions & 9 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def exportVideoClip(self):
"""Exports a selected range of video frames and their corresponding labels."""
self.execute(ExportVideoClip)


# File Commands


Expand Down Expand Up @@ -3473,6 +3474,7 @@ def do_action(context: CommandContext, params: dict):
if rls is not None:
context.openWebsite(rls.url)


class ExportVideoClip(AppCommand):
@staticmethod
def do_action(context: CommandContext, params: dict):
Expand Down Expand Up @@ -3572,8 +3574,12 @@ def ask(context: CommandContext, params: dict) -> bool:
# Prompt the user with a simple Yes/No dialog for "Render Labels"
render_labels_dialog = QtWidgets.QMessageBox()
render_labels_dialog.setWindowTitle("Render Labels")
render_labels_dialog.setText("Do you want to include pose annotations in the video?")
render_labels_dialog.setStandardButtons(QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No)
render_labels_dialog.setText(
"Do you want to include pose annotations in the video?"
)
render_labels_dialog.setStandardButtons(
QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No
)
render_labels_dialog.setDefaultButton(QtWidgets.QMessageBox.Yes)
render_labels = render_labels_dialog.exec_() == QtWidgets.QMessageBox.Yes

Expand Down Expand Up @@ -3629,7 +3635,8 @@ def ask(context: CommandContext, params: dict) -> bool:
params["marker_size"] = context.state.get("marker size", default=4)

return True



class ExportVideoClip(AppCommand):
@staticmethod
def do_action(context: CommandContext, params: dict):
Expand Down Expand Up @@ -3661,7 +3668,9 @@ def do_action(context: CommandContext, params: dict):
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
height, width = video.height, video.width
is_color = video.channels == 3
writer = cv2.VideoWriter(params["filename"], fourcc, params["fps"], (width, height), is_color)
writer = cv2.VideoWriter(
params["filename"], fourcc, params["fps"], (width, height), is_color
)

for frame_idx in params["frames"]:
frame = video.get_frame(frame_idx)
Expand All @@ -3671,14 +3680,18 @@ def do_action(context: CommandContext, params: dict):
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)

if frame.shape[:2] != (height, width):
raise ValueError(f"Frame size {frame.shape[:2]} does not match expected {height, width}")
raise ValueError(
f"Frame size {frame.shape[:2]} does not match expected {height, width}"
)

writer.write(frame)

writer.release()

# Create a new Video object for the output video
new_media_video = MediaVideo(filename=params["filename"], grayscale=video.channels == 1, bgr=True)
new_media_video = MediaVideo(
filename=params["filename"], grayscale=video.channels == 1, bgr=True
)
new_video = Video(backend=new_media_video)

# Step 1: Update all labeled frames to point to the new video
Expand Down Expand Up @@ -3744,16 +3757,15 @@ def ask(context: CommandContext, params: dict) -> bool:
params["filename"] = filename
params["fps"] = export_options["fps"]
params["open_when_done"] = export_options["open_when_done"]



# Access frame range
if context.state.get("has_frame_range"):
params["frames"] = range(*context.state["frame_range"])
else:
params["frames"] = range(context.state["video"].frames)

return True

def copy_to_clipboard(text: str):
"""Copy a string to the system clipboard.
Expand Down
4 changes: 3 additions & 1 deletion sleap/gui/dialogs/export_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sleap.gui.dialogs.formbuilder import FormBuilderModalDialog
from qtpy import QtWidgets


class ExportClipDialog(FormBuilderModalDialog):
def __init__(self):
from sleap.io.videowriter import VideoWriter
Expand All @@ -28,6 +29,7 @@ def __init__(self):

self.setWindowTitle("Export Clip Options")


class ExportClipAndLabelsDialog(FormBuilderModalDialog):
def __init__(self, video_fps=30):
from sleap.io.videowriter import VideoWriter
Expand Down Expand Up @@ -76,4 +78,4 @@ def get_results(self):
"fps": self.fps_input.value(),
"open_when_done": self.open_when_done.isChecked(),
}
return self._results
return self._results
47 changes: 28 additions & 19 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,11 +1058,15 @@ def test_ExportVideoClip_creates_files(tmpdir):
LabeledFrame(
video=video,
frame_idx=i,
instances=[Instance(skeleton=mock_skeleton, track=Track(name=f"track_{i}"))],
instances=[
Instance(skeleton=mock_skeleton, track=Track(name=f"track_{i}"))
],
)
for i in range(frame_count)
]
labels = Labels(labeled_frames=labeled_frames, videos=[video], skeletons=[mock_skeleton])
labels = Labels(
labeled_frames=labeled_frames, videos=[video], skeletons=[mock_skeleton]
)

# Set up CommandContext
context = CommandContext.from_labels(labels)
Expand Down Expand Up @@ -1096,7 +1100,10 @@ def test_ExportVideoClip_creates_files(tmpdir):

# Case 4: Assert that the exported labels match the expected number of frames
exported_labels = Labels.load_file(str(slp_path)) # Convert Path to string here
assert len(exported_labels.labeled_frames) == frame_count, "Incorrect number of frames in exported labels."
assert (
len(exported_labels.labeled_frames) == frame_count
), "Incorrect number of frames in exported labels."


def test_ExportVideoClip_frame_and_video_list_sizes(tmpdir):
"""Test that ExportVideoClip exports correct length labeled frames and video lists with a subset range."""
Expand Down Expand Up @@ -1127,11 +1134,15 @@ def test_ExportVideoClip_frame_and_video_list_sizes(tmpdir):
LabeledFrame(
video=video,
frame_idx=i,
instances=[Instance(skeleton=mock_skeleton, track=Track(name=f"track_{i}"))],
instances=[
Instance(skeleton=mock_skeleton, track=Track(name=f"track_{i}"))
],
)
for i in range(total_frames)
]
labels = Labels(labeled_frames=labeled_frames, videos=[video], skeletons=[mock_skeleton])
labels = Labels(
labeled_frames=labeled_frames, videos=[video], skeletons=[mock_skeleton]
)

# Set up CommandContext with subset frame range
context = CommandContext.from_labels(labels)
Expand All @@ -1151,32 +1162,30 @@ def test_ExportVideoClip_frame_and_video_list_sizes(tmpdir):
# Call ExportVideoClip
ExportVideoClip.do_action(context, params)


slp_path = export_path.with_suffix(".slp")
exported_labels = Labels.load_file(str(slp_path)) # Load exported labels

# Assertions

# Case 1: Check the number of labeled frames
assert len(exported_labels.labeled_frames) == subset_frame_count, (
f"Expected {subset_frame_count} labeled frames, but got {len(exported_labels.labeled_frames)}."
)
assert (
len(exported_labels.labeled_frames) == subset_frame_count
), f"Expected {subset_frame_count} labeled frames, but got {len(exported_labels.labeled_frames)}."

# Case 2: Check the number of videos
assert len(exported_labels.videos) == 1, (
f"Expected 1 video in exported labels, but got {len(exported_labels.videos)}."
)
assert (
len(exported_labels.videos) == 1
), f"Expected 1 video in exported labels, but got {len(exported_labels.videos)}."

# Case 3: Validate that the video in exported labels matches the filename
exported_video = exported_labels.videos[0]
assert exported_video.filename == str(export_path), (
f"Expected video filename to be '{export_path}', but got '{exported_video.filename}'."
)
assert exported_video.filename == str(
export_path
), f"Expected video filename to be '{export_path}', but got '{exported_video.filename}'."

# Case 4: Check the frame indices in labeled frames start from 0
expected_frame_indices = list(range(0, subset_frame_count))
actual_frame_indices = [lf.frame_idx for lf in exported_labels.labeled_frames]
assert actual_frame_indices == expected_frame_indices, (
f"Expected frame indices {expected_frame_indices}, but got {actual_frame_indices}."
)

assert (
actual_frame_indices == expected_frame_indices
), f"Expected frame indices {expected_frame_indices}, but got {actual_frame_indices}."

0 comments on commit 117d5a3

Please sign in to comment.