diff --git a/ultrack/__init__.py b/ultrack/__init__.py index 70e79ab..4c9b848 100644 --- a/ultrack/__init__.py +++ b/ultrack/__init__.py @@ -7,6 +7,13 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) +# Cellpose and ultrack had conflicts due to torch/cuda leading to Segmentation Fault +# importing Cellpose first avoids the issue, https://github.com/royerlab/ultrack/issues/108 +try: + from cellpose.models import Cellpose # noqa: F401 +except (ImportError, ModuleNotFoundError): + pass + # ignoring small float32/64 zero flush warning warnings.filterwarnings("ignore", message="The value of the smallest subnormal for") diff --git a/ultrack/imgproc/segmentation.py b/ultrack/imgproc/segmentation.py index 033ea9c..e94d9a4 100644 --- a/ultrack/imgproc/segmentation.py +++ b/ultrack/imgproc/segmentation.py @@ -1,5 +1,6 @@ +import functools import logging -from typing import Optional +from typing import Callable, Optional import edt import numpy as np @@ -211,20 +212,33 @@ def inverted_edt( return dist +def _maybe_wrap(wrapper_name: str) -> Callable: + """Wraps function with cellpose model method if cellpose is available.""" + try: + from cellpose.models import CellposeModel as _Cellpose + except ImportError: + return lambda x: x + + return functools.wraps(getattr(_Cellpose, wrapper_name)) + + class Cellpose: + @_maybe_wrap("__init__") def __init__(self, **kwargs) -> None: - """See cellpose.models.Cellpose documentation for details.""" - from cellpose.models import CellposeModel as _Cellpose + try: + from cellpose.models import CellposeModel as _Cellpose + except ImportError as e: + raise ImportError( + "Cellpose not found, please install it." + "See for instructions https://github.com/MouseLand/cellpose" + ) from e if "pretrained_model" not in kwargs and "model_type" not in kwargs: kwargs["model_type"] = "cyto" self.model = _Cellpose(**kwargs) + @_maybe_wrap("eval") def __call__(self, image: ArrayLike, **kwargs) -> np.ndarray: - """ - Predicts image labels. - See cellpose.models.Cellpose.eval documentation for details. - """ labels, _, _ = self.model.eval(image, **kwargs) return labels