Skip to content

Commit

Permalink
feat: Add multiple otsu as threshold method with selection range of c…
Browse files Browse the repository at this point in the history
…omponents (#710)

Co-authored-by: Sourcery AI <>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Grzegorz Bokota <[email protected]>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 15, 2022
1 parent 54958d1 commit cb16b0e
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 38 deletions.
35 changes: 13 additions & 22 deletions package/PartSegCore/analysis/measurement_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,12 @@ def get_component_info(self, all_components: bool = False) -> Tuple[bool, bool]:
"""
if all_components and self.components_info.has_components():
return True, True
has_mask_components = any((x == PerComponent.Yes and y != AreaType.ROI for x, y in self._type_dict.values()))
has_mask_components = any(
(
x in {PerComponent.Yes, PerComponent.Per_Mask_component} and y != AreaType.ROI
for x, y in self._type_dict.values()
)
)
has_segmentation_components = any(
(x == PerComponent.Yes and y == AreaType.ROI for x, y in self._type_dict.values())
)
Expand Down Expand Up @@ -876,9 +881,7 @@ def calculate_property(area_array: np.ndarray, channel: np.ndarray, **_): # pyl
channel = channel.reshape(area_array.shape)
else: # pragma: no cover
raise ValueError(f"channel ({channel.shape}) and mask ({area_array.shape}) do not fit each other")
if np.any(area_array):
return np.sum(channel[area_array > 0])
return 0
return np.sum(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand Down Expand Up @@ -908,9 +911,7 @@ class MaximumPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError(f"channel ({channel.shape}) and mask ({area_array.shape}) do not fit each other")
if np.any(area_array):
return np.max(channel[area_array > 0])
return 0
return np.max(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand All @@ -928,9 +929,7 @@ class MinimumPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError("channel and mask do not fit each other")
if np.any(area_array):
return np.min(channel[area_array > 0])
return 0
return np.min(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand All @@ -948,9 +947,7 @@ class MeanPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError("channel and mask do not fit each other")
if np.any(area_array):
return np.mean(channel[area_array > 0])
return 0
return np.mean(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand All @@ -968,9 +965,7 @@ class MedianPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError("channel and mask do not fit each other")
if np.any(area_array):
return np.median(channel[area_array > 0])
return 0
return np.median(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand All @@ -991,9 +986,7 @@ class StandardDeviationOfPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError("channel and mask do not fit each other")
if np.any(area_array):
return np.std(channel[area_array > 0])
return 0
return np.std(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand Down Expand Up @@ -1208,9 +1201,7 @@ def calculate_property(channel, area_array, **kwargs): # pylint: disable=W0221
if border_mask_array is None:
return None
final_mask = np.array((border_mask_array > 0) * (area_array > 0))
if np.any(final_mask):
return np.sum(channel[final_mask])
return 0
return np.sum(channel[final_mask]) if np.any(final_mask) else 0

@classmethod
def get_units(cls, ndim):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@
from .algorithm_base import ROIExtractionAlgorithm, ROIExtractionResult, SegmentationLimitException
from .mu_mid_point import BaseMuMid, MuMidSelection
from .noise_filtering import NoiseFilterSelection
from .threshold import BaseThreshold, DoubleThresholdSelection, ThresholdSelection
from .threshold import (
BaseThreshold,
DoubleThreshold,
DoubleThresholdParams,
DoubleThresholdSelection,
ManualThreshold,
SingleThresholdParams,
ThresholdSelection,
)
from .watershed import BaseWatershed, FlowMethodSelection, calculate_distances_array, get_neigh

REQUIRE_MASK_STR = "Need mask"
Expand Down Expand Up @@ -343,6 +351,7 @@ def get_name(cls):


class TwoThreshold(BaseModel):
# keep for backward compatibility
lower_threshold: float = Field(1000, ge=0, le=10**6)
upper_threshold: float = Field(10000, ge=0, le=10**6)

Expand All @@ -354,9 +363,26 @@ def _to_two_thresholds(dkt):
return dkt


@register_class(version="0.0.1", migrations=[("0.0.1", _to_two_thresholds)])
def _to_double_threshold(dkt):
dkt["threshold"] = DoubleThresholdSelection(
name=DoubleThreshold.get_name(),
values=DoubleThresholdParams(
core_threshold=ThresholdSelection(
name=ManualThreshold.get_name(),
values=SingleThresholdParams(threshold=dkt["threshold"].lower_threshold),
),
base_threshold=ThresholdSelection(
name=ManualThreshold.get_name(),
values=SingleThresholdParams(threshold=dkt["threshold"].upper_threshold),
),
),
)
return dkt


@register_class(version="0.0.2", migrations=[("0.0.1", _to_two_thresholds), ("0.0.2", _to_double_threshold)])
class RangeThresholdAlgorithmParameters(ThresholdBaseAlgorithmParameters):
threshold: TwoThreshold = Field(TwoThreshold(), position=2)
threshold: DoubleThresholdSelection = Field(DoubleThresholdSelection.get_default(), position=2)


class RangeThresholdAlgorithm(ThresholdBaseAlgorithm):
Expand All @@ -369,11 +395,12 @@ class RangeThresholdAlgorithm(ThresholdBaseAlgorithm):
__argument_class__ = RangeThresholdAlgorithmParameters

def _threshold(self, image, thr=None):
self.threshold_info = deepcopy(self.new_parameters.threshold)
return (
(image > self.new_parameters.threshold.lower_threshold)
* np.array(image < self.new_parameters.threshold.upper_threshold)
).astype(np.uint8)
if thr is None:
thr: BaseThreshold = DoubleThresholdSelection[self.new_parameters.threshold.name]
mask, thr_val = thr.calculate_mask(image, self.mask, self.new_parameters.threshold.values, operator.ge)
mask[mask == 2] = 0
self.threshold_info = thr_val
return mask

@classmethod
def get_name(cls):
Expand Down
80 changes: 77 additions & 3 deletions package/PartSegCore/segmentation/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ class SimpleITKThresholdParams256(BaseModel):
bins: int = Field(128, title="Histogram bins", ge=8, le=2**16)


class MultipleOtsuThresholdParams(BaseModel):
components: int = Field(2, title="Number of Components", ge=2, lt=100)
border_component: int = Field(1, title="Border Component", ge=1, lt=100)
valley: bool = Field(True, title="Valley emphasis")
bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16)


class BaseThreshold(AlgorithmDescribeBase, ABC):
@classmethod
def calculate_mask(
Expand Down Expand Up @@ -67,7 +74,7 @@ class SitkThreshold(BaseThreshold, ABC):
def calculate_mask(
cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: SimpleITKThresholdParams128, operator
):
if mask is not None and mask.dtype != np.uint8:
if mask is not None and mask.dtype != np.uint8 and arguments.apply_mask:
mask = (mask > 0).astype(np.uint8)
ob, bg, th_op = (0, 1, np.min) if operator(1, 0) else (1, 0, np.max)
image_sitk = sitk.GetImageFromArray(data)
Expand Down Expand Up @@ -194,7 +201,7 @@ def calculate_threshold(*args, **kwargs):
return sitk.IntermodesThreshold(*args)
except RuntimeError as e:
if "Exceeded maximum iterations for histogram smoothing" in e.args[0]:
raise SegmentationLimitException(*e.args)
raise SegmentationLimitException(*e.args) from e
raise


Expand Down Expand Up @@ -223,7 +230,7 @@ def calculate_threshold(*args, **kwargs):
return sitk.KittlerIllingworthThreshold(*args)
except RuntimeError as e:
if "sigma2 <= 0" in e.args[0]:
raise SegmentationLimitException(*e.args)
raise SegmentationLimitException(*e.args) from e
raise


Expand All @@ -239,6 +246,35 @@ def calculate_threshold(*args, **kwargs):
return sitk.MomentsThreshold(*args)


class MultipleOtsuThreshold(BaseThreshold):
__argument_class__ = MultipleOtsuThresholdParams

@classmethod
def calculate_mask(
cls,
data: np.ndarray,
mask: typing.Optional[np.ndarray],
arguments: MultipleOtsuThresholdParams,
operator: typing.Callable[[object, object], bool],
):
cleaned_image_sitk = sitk.GetImageFromArray(data)
res = sitk.OtsuMultipleThresholds(cleaned_image_sitk, arguments.components, 0, arguments.bins, arguments.valley)
res = sitk.GetArrayFromImage(res)
if operator(1, 0):
res = (res >= arguments.border_component).astype(np.uint8)
threshold = np.min(data[res > 0]) if np.any(res) else np.max(data)
else:
res = (res < arguments.border_component).astype(np.uint8)
threshold = np.max(data[res > 0]) if np.any(res) else np.min(data)
if mask is not None:
res[mask == 0] = 0
return res, threshold

@classmethod
def get_name(cls) -> str:
return "Multiple Otsu"


class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold):
pass

Expand All @@ -256,6 +292,7 @@ class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], s
ThresholdSelection.register(KittlerIllingworthThreshold)
ThresholdSelection.register(MomentsThreshold)
ThresholdSelection.register(MaximumEntropyThreshold)
ThresholdSelection.register(MultipleOtsuThreshold)


class DoubleThresholdParams(BaseModel):
Expand Down Expand Up @@ -316,6 +353,42 @@ def calculate_mask(
return res, (thr1, thr2)


class MultipleOtsuDoubleThresholdParams(BaseModel):
components: int = Field(2, title="Number of Components", ge=2, lt=100)
lower_component: int = Field(1, title="Lower Component", ge=1, lt=100)
upper_component: int = Field(1, title="Upper Component", ge=1, lt=100)
valley: bool = Field(True, title="Valley emphasis")
bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16)


class MultipleOtsu(BaseThreshold):
__argument_class__ = MultipleOtsuDoubleThresholdParams

@classmethod
def get_name(cls):
return "Multiple Otsu"

@classmethod
def calculate_mask(
cls,
data: np.ndarray,
mask: typing.Optional[np.ndarray],
arguments: MultipleOtsuDoubleThresholdParams,
operator: typing.Callable[[object, object], bool],
):
cleaned_image_sitk = sitk.GetImageFromArray(data)
res = sitk.OtsuMultipleThresholds(cleaned_image_sitk, arguments.components, 0, arguments.bins, arguments.valley)
res = sitk.GetArrayFromImage(res)
map_component = np.zeros(arguments.components + 1, dtype=np.uint8)
map_component[: arguments.lower_component] = 0
map_component[arguments.lower_component : arguments.upper_component] = 1
map_component[arguments.upper_component :] = 2
res2 = map_component[res]
thr1 = data[res2 == 2].min() if np.any(res2 == 2) else data[res2 == 1].max()
thr2 = data[res2 == 1].min() if np.any(res2 == 1) else data.max()
return res2, (thr1, thr2)


class DoubleThresholdSelection(
AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold
):
Expand All @@ -324,6 +397,7 @@ class DoubleThresholdSelection(

DoubleThresholdSelection.register(DoubleThreshold)
DoubleThresholdSelection.register(DoubleOtsu)
DoubleThresholdSelection.register(MultipleOtsu)

double_threshold_dict = DoubleThresholdSelection.__register__
threshold_dict = ThresholdSelection.__register__
Expand Down
47 changes: 42 additions & 5 deletions package/tests/test_PartSegCore/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,21 @@ def test_simple(self):
alg = sa.RangeThresholdAlgorithm()
parameters = sa.RangeThresholdAlgorithm.__argument_class__(
threshold={
"lower_threshold": 45,
"upper_threshold": 60,
"name": "Base/Core",
"values": {
"base_threshold": {
"name": "Manual",
"values": {
"threshold": 45,
},
},
"core_threshold": {
"name": "Manual",
"values": {
"threshold": 60,
},
},
},
},
channel=0,
minimum_size=8000,
Expand All @@ -222,7 +235,7 @@ def test_simple(self):
assert result.parameters.values == parameters
assert result.parameters.algorithm == alg.get_name()

parameters.threshold.lower_threshold -= 6
parameters.threshold.values.base_threshold.values.threshold -= 6
alg.set_parameters(parameters)
result = alg.calculation_run(empty)
assert np.max(result.roi) == 1
Expand All @@ -235,8 +248,21 @@ def test_side_connection(self):
alg = sa.RangeThresholdAlgorithm()
parameters = sa.RangeThresholdAlgorithm.__argument_class__(
threshold={
"lower_threshold": 45,
"upper_threshold": 60,
"name": "Base/Core",
"values": {
"base_threshold": {
"name": "Manual",
"values": {
"threshold": 45,
},
},
"core_threshold": {
"name": "Manual",
"values": {
"threshold": 60,
},
},
},
},
channel=0,
minimum_size=8000,
Expand Down Expand Up @@ -825,3 +851,14 @@ def _repr(x):
count = [0]
algorithm_base.dict_repr({1: np.zeros(5), 2: {1: np.zeros(5)}})
assert count[0] == 2


def test_to_double_threshold():
data = {
"threshold": sa.TwoThreshold(
lower_threshold=50,
upper_threshold=100,
)
}
data = sa._to_double_threshold(data)
assert isinstance(data["threshold"], sa.DoubleThresholdSelection)

0 comments on commit cb16b0e

Please sign in to comment.