From ebe82266afdb2b5d80148b3d1483666914b73d60 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Fri, 27 Sep 2024 14:07:58 +0200 Subject: [PATCH] feat: Use image color when add layer to napari (#1200) Use color associated with the channel in napari reader plugin. ## Summary by Sourcery Add functionality to use channel-associated colors in napari layers by adjusting hex color codes. Enhance the `_hex_to_rgb` function to support additional hex code formats and update tests accordingly. New Features: - Integrate color adjustment for image layers in the napari reader plugin, using the color associated with the channel. Enhancements: - Improve the `_hex_to_rgb` function to handle 4 and 8 character hex codes, ensuring proper conversion and validation. Tests: - Expand tests for the `_hex_to_rgb` function to include cases for 4 and 8 character hex codes, ensuring correct functionality and error handling. ## Summary by CodeRabbit - **New Features** - Introduced `adjust_color` function to modify color strings based on length, enhancing color handling. - Updated `_hex_to_rgb` function to support additional hexadecimal color formats. - **Bug Fixes** - Improved validation for hex color codes, ensuring robust error handling for various input lengths. - **Tests** - Added new test cases for `_hex_to_rgb` to validate behavior with hex codes including alpha channels. - Enhanced tests for layer creation and colormap validation in image processing. --- package/PartSeg/common_gui/advanced_tabs.py | 6 ++- package/PartSegCore/napari_plugins/loader.py | 45 +++++++++++++++++++ package/PartSegImage/image.py | 12 +++-- package/tests/conftest.py | 21 +++++++-- package/tests/test_PartSeg/test_viewer.py | 6 +-- .../test_PartSegCore/test_napari_plugins.py | 37 +++++++++++---- package/tests/test_PartSegImage/test_image.py | 4 +- 7 files changed, 110 insertions(+), 21 deletions(-) diff --git a/package/PartSeg/common_gui/advanced_tabs.py b/package/PartSeg/common_gui/advanced_tabs.py index 2244f4df2..2556cf15e 100644 --- a/package/PartSeg/common_gui/advanced_tabs.py +++ b/package/PartSeg/common_gui/advanced_tabs.py @@ -275,4 +275,8 @@ def __init__(self, settings: PartSettings, parent=None): def update_metadata(self): self._dict_viewer.set_data(self.settings.image.metadata) - self.channel_info.setText(f"Channels: {self.settings.image.channel_names}") + text = ", ".join( + f"{name}: {color}" + for name, color in zip(self.settings.image.channel_names, self.settings.image.get_colors()) + ) + self.channel_info.setText(f"Channels with colors: {text}") diff --git a/package/PartSegCore/napari_plugins/loader.py b/package/PartSegCore/napari_plugins/loader.py index 05ef6c225..abd9ee629 100644 --- a/package/PartSegCore/napari_plugins/loader.py +++ b/package/PartSegCore/napari_plugins/loader.py @@ -1,10 +1,54 @@ import typing +from importlib.metadata import version import numpy as np +from packaging.version import parse as parse_version from PartSegCore.analysis import ProjectTuple from PartSegCore.io_utils import LoadBase, WrongFileTypeException from PartSegCore.mask.io_functions import MaskProjectTuple +from PartSegImage import Image + + +@typing.overload +def adjust_color(color: str) -> str: ... + + +@typing.overload +def adjust_color(color: typing.List[int]) -> typing.List[float]: ... + + +def adjust_color(color: typing.Union[str, typing.List[int]]) -> typing.Union[str, typing.List[float]]: + # as napari ignore alpha channel in color, and adding it to + # color cause that napari fails to detect that such colormap is already present + # in this function I remove alpha channel if it is present + if isinstance(color, str) and color.startswith("#"): + if len(color) == 9: + # case when color is in format #RRGGBBAA + return color[:7] + if len(color) == 5: + # case when color is in format #RGBA + return color[:4] + elif isinstance(color, list): + return [color[i] / 255 for i in range(3)] + # If not fit to an earlier case, return as is. + # Maybe napari will handle it + return color + + +if parse_version(version("napari")) >= parse_version("0.4.19a1"): + + def add_color(image: Image, idx: int) -> dict: + return { + "colormap": adjust_color(image.get_colors()[idx]), + } + +else: + + def add_color(image: Image, idx: int) -> dict: # noqa: ARG001 + # Do nothing, as napari is not able to pass hex color to image + # the image and idx are present to keep the same signature + return {} def _image_to_layers(project_info, scale, translate): @@ -27,6 +71,7 @@ def _image_to_layers(project_info, scale, translate): "blending": "additive", "translate": translate, "metadata": project_info.image.metadata, + **add_color(project_info.image, i), }, "image", ) diff --git a/package/PartSegImage/image.py b/package/PartSegImage/image.py index eadb3c847..c8f170c35 100644 --- a/package/PartSegImage/image.py +++ b/package/PartSegImage/image.py @@ -393,6 +393,7 @@ def merge(self, image: Image, axis: str) -> Image: self._channel_arrays + [self.reorder_axes(x, image.array_axis_order) for x in image._channel_arrays] ) channel_names = self._merge_channel_names(self.channel_names, image.channel_names) + color_map = self.default_coloring + image.default_coloring else: index = self.array_axis_order.index(axis) data = self._image_data_normalize( @@ -402,8 +403,11 @@ def merge(self, image: Image, axis: str) -> Image: ] ) channel_names = self.channel_names + color_map = self.default_coloring - return self.substitute(data=data, ranges=self.ranges + image.ranges, channel_names=channel_names) + return self.substitute( + data=data, ranges=self.ranges + image.ranges, channel_names=channel_names, default_coloring=color_map + ) @property def channel_names(self) -> list[str]: @@ -494,7 +498,7 @@ def substitute( channel_names = self.channel_names if channel_names is None else channel_names channel_info = [ - ChannelInfoFull(name=name, color_map=color, contrast_limits=contrast_limits) + ChannelInfo(name=name, color_map=color, contrast_limits=contrast_limits) for name, color, contrast_limits in zip_longest(channel_names, default_coloring, ranges) ] @@ -973,9 +977,9 @@ def _hex_to_rgb(hex_code: str) -> tuple[int, int, int]: """ hex_code = hex_code.lstrip("#") - if len(hex_code) == 3: + if len(hex_code) in {3, 4}: hex_code = "".join([c * 2 for c in hex_code]) - elif len(hex_code) != 6: + elif len(hex_code) not in {6, 8}: raise ValueError(f"Invalid hex code format: {hex_code}") return int(hex_code[:2], 16), int(hex_code[2:4], 16), int(hex_code[4:6], 16) diff --git a/package/tests/conftest.py b/package/tests/conftest.py index dc09f7630..68d484aea 100644 --- a/package/tests/conftest.py +++ b/package/tests/conftest.py @@ -25,7 +25,7 @@ from PartSegCore.roi_info import ROIInfo from PartSegCore.segmentation.restartable_segmentation_algorithms import BorderRim, LowerThresholdAlgorithm from PartSegCore.segmentation.segmentation_algorithm import ThresholdAlgorithm -from PartSegImage import Image +from PartSegImage import ChannelInfo, Image @pytest.fixture(scope="module") @@ -69,12 +69,12 @@ def image2(image, tmp_path): def image2d(tmp_path): data = np.zeros([20, 20], dtype=np.uint8) data[10:-1, 1:-1] = 20 - return Image(data, (10**-3, 10**-3), axes_order="YX", file_path=str(tmp_path / "test.tiff")) + return Image(data, spacing=(10**-3, 10**-3), axes_order="YX", file_path=str(tmp_path / "test.tiff")) @pytest.fixture def stack_image(): - data = np.zeros([20, 40, 40, 2], dtype=np.uint8) + data = np.zeros([20, 40, 40, 3], dtype=np.uint8) for x, y in itertools.product([0, 20], repeat=2): data[1:-1, x + 2 : x + 18, y + 2 : y + 18] = 100 for x, y in itertools.product([0, 20], repeat=2): @@ -82,7 +82,20 @@ def stack_image(): for x, y in itertools.product([0, 20], repeat=2): data[5:-5, x + 6 : x + 14, y + 6 : y + 14] = 140 - return MaskProjectTuple("test_path", Image(data, (2, 1, 1), axes_order="ZYXC", file_path="test_path")) + return MaskProjectTuple( + "test_path", + Image( + data, + spacing=(2, 1, 1), + axes_order="ZYXC", + file_path="test_path", + channel_info=[ + ChannelInfo(name="channel 1", color_map="#00FF00FF"), + ChannelInfo(name="channel 2", color_map="#00FF"), + ChannelInfo(name="channel 3", color_map="#FF0000"), + ], + ), + ) @pytest.fixture diff --git a/package/tests/test_PartSeg/test_viewer.py b/package/tests/test_PartSeg/test_viewer.py index a91e81d42..d75ea943e 100644 --- a/package/tests/test_PartSeg/test_viewer.py +++ b/package/tests/test_PartSeg/test_viewer.py @@ -75,13 +75,13 @@ def test_base(self, image, analysis_segmentation2, tmp_path): assert len(viewer.layers) == 2 settings.image = analysis_segmentation2.image viewer.create_initial_layers(True, True, True, True) - assert len(viewer.layers) == 2 + assert len(viewer.layers) == 3 settings.roi = analysis_segmentation2.roi_info.roi viewer.create_initial_layers(True, True, True, True) - assert len(viewer.layers) == 3 + assert len(viewer.layers) == 4 settings.mask = analysis_segmentation2.mask viewer.create_initial_layers(True, True, True, True) - assert len(viewer.layers) == 4 + assert len(viewer.layers) == 5 viewer.close() def test_points(self, image, tmp_path, qtbot): diff --git a/package/tests/test_PartSegCore/test_napari_plugins.py b/package/tests/test_PartSegCore/test_napari_plugins.py index 8ce07a263..a6e99403f 100644 --- a/package/tests/test_PartSegCore/test_napari_plugins.py +++ b/package/tests/test_PartSegCore/test_napari_plugins.py @@ -1,10 +1,12 @@ # pylint: disable=no-self-use import os +from importlib.metadata import version import numpy as np import pytest from napari.layers import Image, Labels, Layer +from packaging.version import parse as parse_version from PartSegCore.analysis import ProjectTuple from PartSegCore.mask.io_functions import LoadROIFromTIFF @@ -27,20 +29,39 @@ def test_project_to_layers_analysis(analysis_segmentation): analysis_segmentation.roi_info.alternative["test"] = np.zeros(analysis_segmentation.image.shape, dtype=np.uint8) res = project_to_layers(analysis_segmentation) - assert len(res) == 4 + assert len(res) == 5 l1 = Layer.create(*res[0]) assert isinstance(l1, Image) assert l1.name == "channel 1" assert np.allclose(l1.scale[1:] / 1e9, analysis_segmentation.image.spacing) - l2 = Layer.create(*res[2]) - assert isinstance(l2, Labels) - assert l2.name == "ROI" - assert np.allclose(l2.scale[1:] / 1e9, analysis_segmentation.image.spacing) l3 = Layer.create(*res[3]) assert isinstance(l3, Labels) - assert l3.name == "test" + assert l3.name == "ROI" assert np.allclose(l3.scale[1:] / 1e9, analysis_segmentation.image.spacing) - assert not l3.visible + l4 = Layer.create(*res[4]) + assert isinstance(l4, Labels) + assert l4.name == "test" + assert np.allclose(l4.scale[1:] / 1e9, analysis_segmentation.image.spacing) + assert not l4.visible + + +@pytest.mark.skipif( + parse_version(version("napari")) < parse_version("0.4.19a16"), reason="not supported by old napari versions" +) +def test_passing_colormap(analysis_segmentation): + res = project_to_layers(analysis_segmentation) + l1 = Layer.create(*res[0]) + assert isinstance(l1, Image) + assert l1.name == "channel 1" + assert l1.colormap.name == "green" + l2 = Layer.create(*res[1]) + assert isinstance(l2, Image) + assert l2.name == "channel 2" + assert l2.colormap.name == "blue" + l2 = Layer.create(*res[2]) + assert isinstance(l2, Image) + assert l2.name == "channel 3" + assert l2.colormap.name == "red" def test_project_to_layers_roi(): @@ -55,7 +76,7 @@ def test_project_to_layers_roi(): def test_project_to_layers_mask(stack_segmentation1): res = project_to_layers(stack_segmentation1) - assert len(res) == 3 + assert len(res) == 4 assert res[0][2] == "image" diff --git a/package/tests/test_PartSegImage/test_image.py b/package/tests/test_PartSegImage/test_image.py index 8ac148624..1b10bb50d 100644 --- a/package/tests/test_PartSegImage/test_image.py +++ b/package/tests/test_PartSegImage/test_image.py @@ -700,11 +700,13 @@ def test_merge_channel_props_with_none(channel_name, default_coloring, ranges): def test_hex_to_rgb(): assert _hex_to_rgb("#ff0000") == (255, 0, 0) + assert _hex_to_rgb("#ff0000ff") == (255, 0, 0) assert _hex_to_rgb("#00FF00") == (0, 255, 0) assert _hex_to_rgb("#b00") == (187, 0, 0) + assert _hex_to_rgb("#b00f") == (187, 0, 0) assert _hex_to_rgb("#B00") == (187, 0, 0) with pytest.raises(ValueError, match="Invalid hex code format"): - _hex_to_rgb("#b000") + _hex_to_rgb("#b0000") def test_name_to_rgb():