Skip to content

Commit

Permalink
feat: Use image color when add layer to napari (#1200)
Browse files Browse the repository at this point in the history
Use color associated with the channel in napari reader plugin.

<!-- Generated by sourcery-ai[bot]: start summary -->

## Summary by Sourcery

Add functionality to use channel-associated colors in napari layers by
adjusting hex color codes. Enhance the `_hex_to_rgb` function to support
additional hex code formats and update tests accordingly.

New Features:
- Integrate color adjustment for image layers in the napari reader
plugin, using the color associated with the channel.

Enhancements:
- Improve the `_hex_to_rgb` function to handle 4 and 8 character hex
codes, ensuring proper conversion and validation.

Tests:
- Expand tests for the `_hex_to_rgb` function to include cases for 4 and
8 character hex codes, ensuring correct functionality and error
handling.

<!-- Generated by sourcery-ai[bot]: end summary -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced `adjust_color` function to modify color strings based on
length, enhancing color handling.
- Updated `_hex_to_rgb` function to support additional hexadecimal color
formats.

- **Bug Fixes**
- Improved validation for hex color codes, ensuring robust error
handling for various input lengths.

- **Tests**
- Added new test cases for `_hex_to_rgb` to validate behavior with hex
codes including alpha channels.
- Enhanced tests for layer creation and colormap validation in image
processing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
Czaki authored Sep 27, 2024
1 parent 416dde7 commit ebe8226
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 21 deletions.
6 changes: 5 additions & 1 deletion package/PartSeg/common_gui/advanced_tabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,8 @@ def __init__(self, settings: PartSettings, parent=None):

def update_metadata(self):
self._dict_viewer.set_data(self.settings.image.metadata)
self.channel_info.setText(f"Channels: {self.settings.image.channel_names}")
text = ", ".join(
f"{name}: {color}"
for name, color in zip(self.settings.image.channel_names, self.settings.image.get_colors())
)
self.channel_info.setText(f"Channels with colors: {text}")
45 changes: 45 additions & 0 deletions package/PartSegCore/napari_plugins/loader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,54 @@
import typing
from importlib.metadata import version

import numpy as np
from packaging.version import parse as parse_version

from PartSegCore.analysis import ProjectTuple
from PartSegCore.io_utils import LoadBase, WrongFileTypeException
from PartSegCore.mask.io_functions import MaskProjectTuple
from PartSegImage import Image


@typing.overload
def adjust_color(color: str) -> str: ...


@typing.overload
def adjust_color(color: typing.List[int]) -> typing.List[float]: ...


def adjust_color(color: typing.Union[str, typing.List[int]]) -> typing.Union[str, typing.List[float]]:
# as napari ignore alpha channel in color, and adding it to
# color cause that napari fails to detect that such colormap is already present
# in this function I remove alpha channel if it is present
if isinstance(color, str) and color.startswith("#"):
if len(color) == 9:
# case when color is in format #RRGGBBAA
return color[:7]
if len(color) == 5:
# case when color is in format #RGBA
return color[:4]
elif isinstance(color, list):
return [color[i] / 255 for i in range(3)]
# If not fit to an earlier case, return as is.
# Maybe napari will handle it
return color


if parse_version(version("napari")) >= parse_version("0.4.19a1"):

def add_color(image: Image, idx: int) -> dict:
return {
"colormap": adjust_color(image.get_colors()[idx]),
}

else:

def add_color(image: Image, idx: int) -> dict: # noqa: ARG001
# Do nothing, as napari is not able to pass hex color to image
# the image and idx are present to keep the same signature
return {}


def _image_to_layers(project_info, scale, translate):
Expand All @@ -27,6 +71,7 @@ def _image_to_layers(project_info, scale, translate):
"blending": "additive",
"translate": translate,
"metadata": project_info.image.metadata,
**add_color(project_info.image, i),
},
"image",
)
Expand Down
12 changes: 8 additions & 4 deletions package/PartSegImage/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def merge(self, image: Image, axis: str) -> Image:
self._channel_arrays + [self.reorder_axes(x, image.array_axis_order) for x in image._channel_arrays]
)
channel_names = self._merge_channel_names(self.channel_names, image.channel_names)
color_map = self.default_coloring + image.default_coloring
else:
index = self.array_axis_order.index(axis)
data = self._image_data_normalize(
Expand All @@ -402,8 +403,11 @@ def merge(self, image: Image, axis: str) -> Image:
]
)
channel_names = self.channel_names
color_map = self.default_coloring

return self.substitute(data=data, ranges=self.ranges + image.ranges, channel_names=channel_names)
return self.substitute(
data=data, ranges=self.ranges + image.ranges, channel_names=channel_names, default_coloring=color_map
)

@property
def channel_names(self) -> list[str]:
Expand Down Expand Up @@ -494,7 +498,7 @@ def substitute(
channel_names = self.channel_names if channel_names is None else channel_names

channel_info = [
ChannelInfoFull(name=name, color_map=color, contrast_limits=contrast_limits)
ChannelInfo(name=name, color_map=color, contrast_limits=contrast_limits)
for name, color, contrast_limits in zip_longest(channel_names, default_coloring, ranges)
]

Expand Down Expand Up @@ -973,9 +977,9 @@ def _hex_to_rgb(hex_code: str) -> tuple[int, int, int]:
"""
hex_code = hex_code.lstrip("#")

if len(hex_code) == 3:
if len(hex_code) in {3, 4}:
hex_code = "".join([c * 2 for c in hex_code])
elif len(hex_code) != 6:
elif len(hex_code) not in {6, 8}:
raise ValueError(f"Invalid hex code format: {hex_code}")

return int(hex_code[:2], 16), int(hex_code[2:4], 16), int(hex_code[4:6], 16)
Expand Down
21 changes: 17 additions & 4 deletions package/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from PartSegCore.roi_info import ROIInfo
from PartSegCore.segmentation.restartable_segmentation_algorithms import BorderRim, LowerThresholdAlgorithm
from PartSegCore.segmentation.segmentation_algorithm import ThresholdAlgorithm
from PartSegImage import Image
from PartSegImage import ChannelInfo, Image


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -69,20 +69,33 @@ def image2(image, tmp_path):
def image2d(tmp_path):
data = np.zeros([20, 20], dtype=np.uint8)
data[10:-1, 1:-1] = 20
return Image(data, (10**-3, 10**-3), axes_order="YX", file_path=str(tmp_path / "test.tiff"))
return Image(data, spacing=(10**-3, 10**-3), axes_order="YX", file_path=str(tmp_path / "test.tiff"))


@pytest.fixture
def stack_image():
data = np.zeros([20, 40, 40, 2], dtype=np.uint8)
data = np.zeros([20, 40, 40, 3], dtype=np.uint8)
for x, y in itertools.product([0, 20], repeat=2):
data[1:-1, x + 2 : x + 18, y + 2 : y + 18] = 100
for x, y in itertools.product([0, 20], repeat=2):
data[3:-3, x + 4 : x + 16, y + 4 : y + 16] = 120
for x, y in itertools.product([0, 20], repeat=2):
data[5:-5, x + 6 : x + 14, y + 6 : y + 14] = 140

return MaskProjectTuple("test_path", Image(data, (2, 1, 1), axes_order="ZYXC", file_path="test_path"))
return MaskProjectTuple(
"test_path",
Image(
data,
spacing=(2, 1, 1),
axes_order="ZYXC",
file_path="test_path",
channel_info=[
ChannelInfo(name="channel 1", color_map="#00FF00FF"),
ChannelInfo(name="channel 2", color_map="#00FF"),
ChannelInfo(name="channel 3", color_map="#FF0000"),
],
),
)


@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions package/tests/test_PartSeg/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ def test_base(self, image, analysis_segmentation2, tmp_path):
assert len(viewer.layers) == 2
settings.image = analysis_segmentation2.image
viewer.create_initial_layers(True, True, True, True)
assert len(viewer.layers) == 2
assert len(viewer.layers) == 3
settings.roi = analysis_segmentation2.roi_info.roi
viewer.create_initial_layers(True, True, True, True)
assert len(viewer.layers) == 3
assert len(viewer.layers) == 4
settings.mask = analysis_segmentation2.mask
viewer.create_initial_layers(True, True, True, True)
assert len(viewer.layers) == 4
assert len(viewer.layers) == 5
viewer.close()

def test_points(self, image, tmp_path, qtbot):
Expand Down
37 changes: 29 additions & 8 deletions package/tests/test_PartSegCore/test_napari_plugins.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# pylint: disable=no-self-use

import os
from importlib.metadata import version

import numpy as np
import pytest
from napari.layers import Image, Labels, Layer
from packaging.version import parse as parse_version

from PartSegCore.analysis import ProjectTuple
from PartSegCore.mask.io_functions import LoadROIFromTIFF
Expand All @@ -27,20 +29,39 @@
def test_project_to_layers_analysis(analysis_segmentation):
analysis_segmentation.roi_info.alternative["test"] = np.zeros(analysis_segmentation.image.shape, dtype=np.uint8)
res = project_to_layers(analysis_segmentation)
assert len(res) == 4
assert len(res) == 5
l1 = Layer.create(*res[0])
assert isinstance(l1, Image)
assert l1.name == "channel 1"
assert np.allclose(l1.scale[1:] / 1e9, analysis_segmentation.image.spacing)
l2 = Layer.create(*res[2])
assert isinstance(l2, Labels)
assert l2.name == "ROI"
assert np.allclose(l2.scale[1:] / 1e9, analysis_segmentation.image.spacing)
l3 = Layer.create(*res[3])
assert isinstance(l3, Labels)
assert l3.name == "test"
assert l3.name == "ROI"
assert np.allclose(l3.scale[1:] / 1e9, analysis_segmentation.image.spacing)
assert not l3.visible
l4 = Layer.create(*res[4])
assert isinstance(l4, Labels)
assert l4.name == "test"
assert np.allclose(l4.scale[1:] / 1e9, analysis_segmentation.image.spacing)
assert not l4.visible


@pytest.mark.skipif(
parse_version(version("napari")) < parse_version("0.4.19a16"), reason="not supported by old napari versions"
)
def test_passing_colormap(analysis_segmentation):
res = project_to_layers(analysis_segmentation)
l1 = Layer.create(*res[0])
assert isinstance(l1, Image)
assert l1.name == "channel 1"
assert l1.colormap.name == "green"
l2 = Layer.create(*res[1])
assert isinstance(l2, Image)
assert l2.name == "channel 2"
assert l2.colormap.name == "blue"
l2 = Layer.create(*res[2])
assert isinstance(l2, Image)
assert l2.name == "channel 3"
assert l2.colormap.name == "red"


def test_project_to_layers_roi():
Expand All @@ -55,7 +76,7 @@ def test_project_to_layers_roi():

def test_project_to_layers_mask(stack_segmentation1):
res = project_to_layers(stack_segmentation1)
assert len(res) == 3
assert len(res) == 4
assert res[0][2] == "image"


Expand Down
4 changes: 3 additions & 1 deletion package/tests/test_PartSegImage/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,11 +700,13 @@ def test_merge_channel_props_with_none(channel_name, default_coloring, ranges):

def test_hex_to_rgb():
assert _hex_to_rgb("#ff0000") == (255, 0, 0)
assert _hex_to_rgb("#ff0000ff") == (255, 0, 0)
assert _hex_to_rgb("#00FF00") == (0, 255, 0)
assert _hex_to_rgb("#b00") == (187, 0, 0)
assert _hex_to_rgb("#b00f") == (187, 0, 0)
assert _hex_to_rgb("#B00") == (187, 0, 0)
with pytest.raises(ValueError, match="Invalid hex code format"):
_hex_to_rgb("#b000")
_hex_to_rgb("#b0000")


def test_name_to_rgb():
Expand Down

0 comments on commit ebe8226

Please sign in to comment.