Skip to content

Commit

Permalink
updated comments
Browse files Browse the repository at this point in the history
  • Loading branch information
7174Andy committed May 23, 2024
1 parent 2b8131e commit ff0aedc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 83 deletions.
4 changes: 0 additions & 4 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,10 +1162,6 @@ def _update_gui_state(self):
and has_selected_session
)

# Update color predicted
if self.state["distinctly_color"] == "instance_groups":
self.state["color predicted"] = False

# Update overlays
self.overlays["track_labels"].visible = (
control_key_down and has_selected_instance
Expand Down
60 changes: 10 additions & 50 deletions sleap/gui/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,12 @@ class ColorManager:
labels: The :class:`Labels` dataset which contains the tracks for
which we want colors.
palette: String with the color palette name to use.
session: The :class:`RecordingSession` object which contains the
instance groups for which we want colors.
frame_idx: The index of the frame in the session for which we want
instance group colors.
"""

def __init__(
self,
labels: Labels = None,
palette: str = "standard",
session: RecordingSession = None,
frame_idx: int = None,
):
self.labels = labels

Expand Down Expand Up @@ -77,9 +70,6 @@ def __init__(
self.medium_pen_width = self.thick_pen_width // 2
self.default_pen_width = max(1, self.thick_pen_width // 4)

self.session = session
self.frame_idx = frame_idx

@property
def labels(self):
"""Gets or sets labels dataset for which we are coloring tracks."""
Expand Down Expand Up @@ -123,14 +113,6 @@ def tracks(self) -> Iterable[Track]:
return self.labels.tracks
return []

@property
def instance_groups(self) -> Iterable[InstanceGroup]:
"""Gets instance groups for project."""
frame_group = self.session.frame_groups.get(self.frame_idx, None)
if frame_group is not None:
return frame_group.instance_groups
return []

def set_palette(self, palette: Union[Text, Iterable[ColorTupleStringType]]):
"""Functional alias for palette property setter."""
self.palette = palette
Expand Down Expand Up @@ -202,28 +184,6 @@ def get_track_color(self, track: Union[Track, int]) -> ColorTupleType:

return self.get_color_by_idx(track_idx)

def get_instance_group_color(
self, instance_group: Union[InstanceGroup, int]
) -> ColorTupleType:
"""Returns the color to use for a given instance group.
Args:
instance_group: `InstanceGroup` object
Returns:
(r, g, b)-tuple
"""
instance_group_idx = instance_group
if isinstance(instance_group, InstanceGroup):
instance_group_idx = (
self.instance_groups.index(instance_group)
if instance_group in self.instance_groups
else None
)
if instance_group_idx is None:
return (0, 0, 0)

return self.get_color_by_idx(instance_group_idx)

@classmethod
def is_sequence(cls, item) -> bool:
"""Returns whether item is a tuple or list."""
Expand Down Expand Up @@ -295,6 +255,16 @@ def get_item_color(
if not parent_skeleton and hasattr(parent_instance, "skeleton"):
parent_skeleton = parent_instance.skeleton

is_predicted = False
if parent_instance and self.is_predicted(parent_instance):
is_predicted = True

if is_predicted and not self.color_predicted:
if isinstance(item, Node):
return self.uncolored_prediction_color

return (128, 128, 128)

if parent_frame_idx is None and parent_instance:
parent_frame = parent_instance.frame
if parent_frame:
Expand All @@ -318,16 +288,6 @@ def get_item_color(
instance_group_idx = frame_group.instance_groups.index(instance_group)
return self.get_color_by_idx(instance_group_idx)

is_predicted = False
if parent_instance and self.is_predicted(parent_instance):
is_predicted = True

if is_predicted and not self.color_predicted:
if isinstance(item, Node):
return self.uncolored_prediction_color

return (128, 128, 128)

if self.distinctly_color == "instances" or hasattr(item, "track"):
track = None
if hasattr(item, "track"):
Expand Down
29 changes: 0 additions & 29 deletions tests/gui/test_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,32 +73,3 @@ def test_track_color(centered_pair_predictions):
assert color_manager.get_item_color(
inst_0.skeleton.edges[edge_idx], inst_0
) == color_manager.get_color_by_idx(edge_idx)


def test_instance_group_color(multiview_min_session_frame_groups):
labels = multiview_min_session_frame_groups
session = labels.sessions[0]
frame_idx = 0
frame_group = session.frame_groups[frame_idx]
instance_group = frame_group.instance_groups[0]

# Test instance group colors
color_manager = ColorManager(labels=labels, session=session, frame_idx=frame_idx)

# Test instance groups are stored correctly
assert color_manager.labels == labels
assert color_manager.session == session
assert color_manager.frame_idx == frame_idx
assert color_manager.instance_groups == frame_group.instance_groups

# Test instance group colors
assert color_manager.get_instance_group_color(
instance_group
) == color_manager.get_color_by_idx(0)
assert list(color_manager.get_instance_group_color(instance_group)) != [0, 0, 0]

# Test whether if the instance group color is the same as the instance color
instance = instance_group.instances[0]
assert color_manager.get_item_color(
item=instance, parent_session=session, parent_frame_idx=0
) == color_manager.get_instance_group_color(instance_group)

0 comments on commit ff0aedc

Please sign in to comment.