Skip to content

Commit

Permalink
map frame_to_indedx and find valid_frame_indices for labels
Browse files Browse the repository at this point in the history
  • Loading branch information
ericleonardis committed Dec 18, 2024
1 parent e972d3f commit 39cc496
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3485,53 +3485,70 @@ class ExportClipVideo(AppCommand):
@staticmethod
def do_action(context: CommandContext, params: dict):
"""
Exports a pruned video clip and labels to a specified file based on selected frame range.
Executes the export action for a video clip and its labels.
Args:
context (CommandContext): Contains state information like video and labels.
params (dict): Parameters including filename, fps, and open_when_done.
context (CommandContext): The application context containing the current state.
params (dict): Parameters for the export, including:
- 'filename' (str): The path to save the exported video.
- 'fps' (int): Frames per second for the exported video.
- 'open_when_done' (bool): Whether to open the video file after exporting.
Raises:
ValueError: If the frame range is invalid or no clip is selected.
RuntimeError: If there are issues with video writing or saving labels.
"""
# Extract video and labels from context
video = context.state["video"]
labels = context.state["labels"]

# Ensure frame range is set; default to all frames if None
frame_range = context.state.get("frame_range", (0, video.frames))

# Validate frame range
if frame_range[0] < 0 or frame_range[1] > video.frames:
raise ValueError(f"Frame range {frame_range} is outside video bounds [0, {video.frames}]")

# Check if clip is selected, raise error if no clip selected
if frame_range == (0, video.frames) or frame_range == (0, 1) or frame_range[0] == frame_range[1]:
raise ValueError("No valid clip frame range selected! Please select a valid frame range using shift + click in the GUI.")

# Map frame indices to the actual labeled frame objects
frame_to_index = {lf.frame_idx: idx for idx, lf in enumerate(labels.labeled_frames) if lf.video == video}
valid_frame_indices = [frame for frame in range(*frame_range) if frame in frame_to_index]

if not valid_frame_indices:
raise ValueError("No valid labeled frames found in the selected frame range.")

# Extract only the selected frames into a new Labels object
pruned_labels = labels.extract(
inds=range(*frame_range),
copy=True, # Ensures a deep copy of the extracted labels
inds=[frame_to_index[frame] for frame in valid_frame_indices],
copy=True # Ensures a deep copy of the extracted labels
)

# Remap frame indices in pruned_labels to start from 0
# Remap frame indices in pruned_labels to start from 0 while maintaining spacing
frame_offset = frame_range[0]
for labeled_frame in pruned_labels.labeled_frames:
labeled_frame.frame_idx -= frame_range[0]
labeled_frame.frame_idx -= frame_offset

# Initialize VideoWriter
height, width = video.height, video.width
fps = params["fps"]
writer = VideoWriter.safe_builder(params["filename"], height, width, fps)

# Conditionally show progress bar
# Conditionally show progress bar
show_progress = os.getenv("PYTEST_RUNNING") != "1"
if show_progress:
app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([])
progress = QtWidgets.QProgressDialog("Exporting video...", "Cancel", 0, len(range(*frame_range)))
progress = QtWidgets.QProgressDialog("Exporting video...", "Cancel", 0, len(valid_frame_indices))
progress.setWindowModality(QtCore.Qt.WindowModal)
progress.setValue(0)
else:
progress = None # Progress bar disabled during tests

# Write frames to the video
try:
for idx, frame_idx in enumerate(range(*frame_range)):
for idx, frame_idx in enumerate(valid_frame_indices):
if show_progress and progress.wasCanceled():
writer.close()
os.remove(params["filename"])
Expand All @@ -3556,7 +3573,7 @@ def do_action(context: CommandContext, params: dict):
finally:
writer.close()
if show_progress:
progress.setValue(frame_range[1] - frame_range[0]) # Complete progress
progress.setValue(len(valid_frame_indices)) # Complete progress

# Create a new Video object for the output video
new_media_video = MediaVideo(
Expand Down Expand Up @@ -3617,13 +3634,14 @@ def ask(context: CommandContext, params: dict) -> bool:
params["open_when_done"] = export_options["open_when_done"]

# Access frame range
if context.state.get("has_frame_range"):
if context.state.get("frame_range"):
params["frames"] = range(*context.state["frame_range"])
else:
params["frames"] = range(context.state["video"].frames)

return True


class ExportClipPkg(AppCommand):
@staticmethod
def do_action(context: CommandContext, params: dict):
Expand All @@ -3643,12 +3661,19 @@ def do_action(context: CommandContext, params: dict):

# Check if clip is selected, raise error if no clip selected
if frame_range == (0, video.frames) or frame_range == (0, 1) or frame_range[0] == frame_range[1]:
raise ValueError("No valid clip frame range selected! Please select a valid frame range using shift + click in the GUI.")
raise ValueError("No valid clip frame range selected! Please select a valid frame range using shift + click in the GUI.")

# Map frame indices to the actual labeled frame objects
frame_to_index = {lf.frame_idx: idx for idx, lf in enumerate(labels.labeled_frames) if lf.video == video}
valid_frame_indices = [frame for frame in range(*frame_range) if frame in frame_to_index]

if not valid_frame_indices:
raise ValueError("No valid labeled frames found in the selected frame range.")

# Extract only the selected frames into a new Labels object
pruned_labels = labels.extract(
inds=range(*frame_range),
copy=True, # Ensures a deep copy of the extracted labels
inds=[frame_to_index[frame] for frame in valid_frame_indices],
copy=True # Ensures a deep copy of the extracted labels
)

# Remap frame indices in pruned_labels to start from 0
Expand Down

0 comments on commit 39cc496

Please sign in to comment.