Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add multiple otsu as threshold method with selection range of components #710

Merged
merged 13 commits into from
Oct 15, 2022
Merged
35 changes: 13 additions & 22 deletions package/PartSegCore/analysis/measurement_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,12 @@ def get_component_info(self, all_components: bool = False) -> Tuple[bool, bool]:
"""
if all_components and self.components_info.has_components():
return True, True
has_mask_components = any((x == PerComponent.Yes and y != AreaType.ROI for x, y in self._type_dict.values()))
has_mask_components = any(
(
x in {PerComponent.Yes, PerComponent.Per_Mask_component} and y != AreaType.ROI
for x, y in self._type_dict.values()
)
)
has_segmentation_components = any(
(x == PerComponent.Yes and y == AreaType.ROI for x, y in self._type_dict.values())
)
Expand Down Expand Up @@ -876,9 +881,7 @@ def calculate_property(area_array: np.ndarray, channel: np.ndarray, **_): # pyl
channel = channel.reshape(area_array.shape)
else: # pragma: no cover
raise ValueError(f"channel ({channel.shape}) and mask ({area_array.shape}) do not fit each other")
if np.any(area_array):
return np.sum(channel[area_array > 0])
return 0
return np.sum(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand Down Expand Up @@ -908,9 +911,7 @@ class MaximumPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError(f"channel ({channel.shape}) and mask ({area_array.shape}) do not fit each other")
if np.any(area_array):
return np.max(channel[area_array > 0])
return 0
return np.max(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand All @@ -928,9 +929,7 @@ class MinimumPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError("channel and mask do not fit each other")
if np.any(area_array):
return np.min(channel[area_array > 0])
return 0
return np.min(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand All @@ -948,9 +947,7 @@ class MeanPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError("channel and mask do not fit each other")
if np.any(area_array):
return np.mean(channel[area_array > 0])
return 0
return np.mean(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand All @@ -968,9 +965,7 @@ class MedianPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError("channel and mask do not fit each other")
if np.any(area_array):
return np.median(channel[area_array > 0])
return 0
return np.median(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand All @@ -991,9 +986,7 @@ class StandardDeviationOfPixelBrightness(MeasurementMethodBase):
def calculate_property(area_array, channel, **_): # pylint: disable=W0221
if area_array.shape != channel.shape: # pragma: no cover
raise ValueError("channel and mask do not fit each other")
if np.any(area_array):
return np.std(channel[area_array > 0])
return 0
return np.std(channel[area_array > 0]) if np.any(area_array) else 0

@classmethod
def get_units(cls, ndim):
Expand Down Expand Up @@ -1208,9 +1201,7 @@ def calculate_property(channel, area_array, **kwargs): # pylint: disable=W0221
if border_mask_array is None:
return None
final_mask = np.array((border_mask_array > 0) * (area_array > 0))
if np.any(final_mask):
return np.sum(channel[final_mask])
return 0
return np.sum(channel[final_mask]) if np.any(final_mask) else 0

@classmethod
def get_units(cls, ndim):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@
from .algorithm_base import ROIExtractionAlgorithm, ROIExtractionResult, SegmentationLimitException
from .mu_mid_point import BaseMuMid, MuMidSelection
from .noise_filtering import NoiseFilterSelection
from .threshold import BaseThreshold, DoubleThresholdSelection, ThresholdSelection
from .threshold import (
BaseThreshold,
DoubleThreshold,
DoubleThresholdParams,
DoubleThresholdSelection,
ManualThreshold,
SingleThresholdParams,
ThresholdSelection,
)
from .watershed import BaseWatershed, FlowMethodSelection, calculate_distances_array, get_neigh

REQUIRE_MASK_STR = "Need mask"
Expand Down Expand Up @@ -343,6 +351,7 @@ def get_name(cls):


class TwoThreshold(BaseModel):
# keep for backward compatibility
lower_threshold: float = Field(1000, ge=0, le=10**6)
upper_threshold: float = Field(10000, ge=0, le=10**6)

Expand All @@ -354,9 +363,26 @@ def _to_two_thresholds(dkt):
return dkt


@register_class(version="0.0.1", migrations=[("0.0.1", _to_two_thresholds)])
def _to_double_threshold(dkt):
dkt["threshold"] = DoubleThresholdSelection(
name=DoubleThreshold.get_name(),
values=DoubleThresholdParams(
core_threshold=ThresholdSelection(
name=ManualThreshold.get_name(),
values=SingleThresholdParams(threshold=dkt["threshold"].lower_threshold),
),
base_threshold=ThresholdSelection(
name=ManualThreshold.get_name(),
values=SingleThresholdParams(threshold=dkt["threshold"].upper_threshold),
),
),
)
return dkt


@register_class(version="0.0.2", migrations=[("0.0.1", _to_two_thresholds), ("0.0.2", _to_double_threshold)])
class RangeThresholdAlgorithmParameters(ThresholdBaseAlgorithmParameters):
threshold: TwoThreshold = Field(TwoThreshold(), position=2)
threshold: DoubleThresholdSelection = Field(DoubleThresholdSelection.get_default(), position=2)


class RangeThresholdAlgorithm(ThresholdBaseAlgorithm):
Expand All @@ -369,11 +395,12 @@ class RangeThresholdAlgorithm(ThresholdBaseAlgorithm):
__argument_class__ = RangeThresholdAlgorithmParameters

def _threshold(self, image, thr=None):
self.threshold_info = deepcopy(self.new_parameters.threshold)
return (
(image > self.new_parameters.threshold.lower_threshold)
* np.array(image < self.new_parameters.threshold.upper_threshold)
).astype(np.uint8)
if thr is None:
thr: BaseThreshold = DoubleThresholdSelection[self.new_parameters.threshold.name]
mask, thr_val = thr.calculate_mask(image, self.mask, self.new_parameters.threshold.values, operator.ge)
mask[mask == 2] = 0
self.threshold_info = thr_val
return mask

@classmethod
def get_name(cls):
Expand Down
80 changes: 77 additions & 3 deletions package/PartSegCore/segmentation/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ class SimpleITKThresholdParams256(BaseModel):
bins: int = Field(128, title="Histogram bins", ge=8, le=2**16)


class MultipleOtsuThresholdParams(BaseModel):
components: int = Field(2, title="Number of Components", ge=2, lt=100)
border_component: int = Field(1, title="Border Component", ge=1, lt=100)
valley: bool = Field(True, title="Valley emphasis")
bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16)


class BaseThreshold(AlgorithmDescribeBase, ABC):
@classmethod
def calculate_mask(
Expand Down Expand Up @@ -67,7 +74,7 @@ class SitkThreshold(BaseThreshold, ABC):
def calculate_mask(
cls, data: np.ndarray, mask: typing.Optional[np.ndarray], arguments: SimpleITKThresholdParams128, operator
):
if mask is not None and mask.dtype != np.uint8:
if mask is not None and mask.dtype != np.uint8 and arguments.apply_mask:
mask = (mask > 0).astype(np.uint8)
ob, bg, th_op = (0, 1, np.min) if operator(1, 0) else (1, 0, np.max)
image_sitk = sitk.GetImageFromArray(data)
Expand Down Expand Up @@ -194,7 +201,7 @@ def calculate_threshold(*args, **kwargs):
return sitk.IntermodesThreshold(*args)
except RuntimeError as e:
if "Exceeded maximum iterations for histogram smoothing" in e.args[0]:
raise SegmentationLimitException(*e.args)
raise SegmentationLimitException(*e.args) from e
raise


Expand Down Expand Up @@ -223,7 +230,7 @@ def calculate_threshold(*args, **kwargs):
return sitk.KittlerIllingworthThreshold(*args)
except RuntimeError as e:
if "sigma2 <= 0" in e.args[0]:
raise SegmentationLimitException(*e.args)
raise SegmentationLimitException(*e.args) from e
raise


Expand All @@ -239,6 +246,35 @@ def calculate_threshold(*args, **kwargs):
return sitk.MomentsThreshold(*args)


class MultipleOtsuThreshold(BaseThreshold):
__argument_class__ = MultipleOtsuThresholdParams

@classmethod
def calculate_mask(
cls,
data: np.ndarray,
mask: typing.Optional[np.ndarray],
arguments: MultipleOtsuThresholdParams,
operator: typing.Callable[[object, object], bool],
):
cleaned_image_sitk = sitk.GetImageFromArray(data)
res = sitk.OtsuMultipleThresholds(cleaned_image_sitk, arguments.components, 0, arguments.bins, arguments.valley)
res = sitk.GetArrayFromImage(res)
if operator(1, 0):
res = (res >= arguments.border_component).astype(np.uint8)
threshold = np.min(data[res > 0]) if np.any(res) else np.max(data)
else:
res = (res < arguments.border_component).astype(np.uint8)
threshold = np.max(data[res > 0]) if np.any(res) else np.min(data)
if mask is not None:
res[mask == 0] = 0
return res, threshold

@classmethod
def get_name(cls) -> str:
return "Multiple Otsu"


class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold):
pass

Expand All @@ -256,6 +292,7 @@ class ThresholdSelection(AlgorithmSelection, class_methods=["calculate_mask"], s
ThresholdSelection.register(KittlerIllingworthThreshold)
ThresholdSelection.register(MomentsThreshold)
ThresholdSelection.register(MaximumEntropyThreshold)
ThresholdSelection.register(MultipleOtsuThreshold)


class DoubleThresholdParams(BaseModel):
Expand Down Expand Up @@ -316,6 +353,42 @@ def calculate_mask(
return res, (thr1, thr2)


class MultipleOtsuDoubleThresholdParams(BaseModel):
components: int = Field(2, title="Number of Components", ge=2, lt=100)
lower_component: int = Field(1, title="Lower Component", ge=1, lt=100)
upper_component: int = Field(1, title="Upper Component", ge=1, lt=100)
valley: bool = Field(True, title="Valley emphasis")
bins: int = Field(128, title="Number of histogram bins", ge=8, le=2**16)


class MultipleOtsu(BaseThreshold):
__argument_class__ = MultipleOtsuDoubleThresholdParams

@classmethod
def get_name(cls):
return "Multiple Otsu"

@classmethod
def calculate_mask(
cls,
data: np.ndarray,
mask: typing.Optional[np.ndarray],
arguments: MultipleOtsuDoubleThresholdParams,
operator: typing.Callable[[object, object], bool],
):
cleaned_image_sitk = sitk.GetImageFromArray(data)
res = sitk.OtsuMultipleThresholds(cleaned_image_sitk, arguments.components, 0, arguments.bins, arguments.valley)
res = sitk.GetArrayFromImage(res)
map_component = np.zeros(arguments.components + 1, dtype=np.uint8)
map_component[: arguments.lower_component] = 0
map_component[arguments.lower_component : arguments.upper_component] = 1
map_component[arguments.upper_component :] = 2
res2 = map_component[res]
thr1 = data[res2 == 2].min() if np.any(res2 == 2) else data[res2 == 1].max()
thr2 = data[res2 == 1].min() if np.any(res2 == 1) else data.max()
return res2, (thr1, thr2)


class DoubleThresholdSelection(
AlgorithmSelection, class_methods=["calculate_mask"], suggested_base_class=BaseThreshold
):
Expand All @@ -324,6 +397,7 @@ class DoubleThresholdSelection(

DoubleThresholdSelection.register(DoubleThreshold)
DoubleThresholdSelection.register(DoubleOtsu)
DoubleThresholdSelection.register(MultipleOtsu)

double_threshold_dict = DoubleThresholdSelection.__register__
threshold_dict = ThresholdSelection.__register__
Expand Down
47 changes: 42 additions & 5 deletions package/tests/test_PartSegCore/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,21 @@ def test_simple(self):
alg = sa.RangeThresholdAlgorithm()
parameters = sa.RangeThresholdAlgorithm.__argument_class__(
threshold={
"lower_threshold": 45,
"upper_threshold": 60,
"name": "Base/Core",
"values": {
"base_threshold": {
"name": "Manual",
"values": {
"threshold": 45,
},
},
"core_threshold": {
"name": "Manual",
"values": {
"threshold": 60,
},
},
},
},
channel=0,
minimum_size=8000,
Expand All @@ -222,7 +235,7 @@ def test_simple(self):
assert result.parameters.values == parameters
assert result.parameters.algorithm == alg.get_name()

parameters.threshold.lower_threshold -= 6
parameters.threshold.values.base_threshold.values.threshold -= 6
alg.set_parameters(parameters)
result = alg.calculation_run(empty)
assert np.max(result.roi) == 1
Expand All @@ -235,8 +248,21 @@ def test_side_connection(self):
alg = sa.RangeThresholdAlgorithm()
parameters = sa.RangeThresholdAlgorithm.__argument_class__(
threshold={
"lower_threshold": 45,
"upper_threshold": 60,
"name": "Base/Core",
"values": {
"base_threshold": {
"name": "Manual",
"values": {
"threshold": 45,
},
},
"core_threshold": {
"name": "Manual",
"values": {
"threshold": 60,
},
},
},
},
channel=0,
minimum_size=8000,
Expand Down Expand Up @@ -825,3 +851,14 @@ def _repr(x):
count = [0]
algorithm_base.dict_repr({1: np.zeros(5), 2: {1: np.zeros(5)}})
assert count[0] == 2


def test_to_double_threshold():
data = {
"threshold": sa.TwoThreshold(
lower_threshold=50,
upper_threshold=100,
)
}
data = sa._to_double_threshold(data)
assert isinstance(data["threshold"], sa.DoubleThresholdSelection)