diff --git a/pyproject.toml b/pyproject.toml index 34030a5..28771bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,4 +45,4 @@ exclude = ['setup*'] ignore_missing_imports = true [project.scripts] -cellmap.add_cellpose = "cellmap_models.cellpose:add_model" \ No newline at end of file +"cellmap.add_cellpose" = "cellmap_models.pytorch.cellpose:add_model" diff --git a/src/cellmap_models/__pycache__/__init__.cpython-310.pyc b/src/cellmap_models/__pycache__/__init__.cpython-310.pyc index 9c1eceb..e649df2 100644 Binary files a/src/cellmap_models/__pycache__/__init__.cpython-310.pyc and b/src/cellmap_models/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/cellmap_models/__pycache__/utils.cpython-310.pyc b/src/cellmap_models/__pycache__/utils.cpython-310.pyc index 15284a9..b067da1 100644 Binary files a/src/cellmap_models/__pycache__/utils.cpython-310.pyc and b/src/cellmap_models/__pycache__/utils.cpython-310.pyc differ diff --git a/src/cellmap_models/pytorch/cellpose/__init__.py b/src/cellmap_models/pytorch/cellpose/__init__.py index c8faf0a..6e0bc43 100644 --- a/src/cellmap_models/pytorch/cellpose/__init__.py +++ b/src/cellmap_models/pytorch/cellpose/__init__.py @@ -1,5 +1,6 @@ from .add_model import add_model from .load_model import load_model +from .get_model import get_model models_dict = { "jrc_mus-epididymis-1_nuc_cp": "https://github.com/janelia-cellmap/cellmap-models/releases/download/2024.03.08/jrc_mus-epididymis-1_nuc_cp", diff --git a/src/cellmap_models/pytorch/cellpose/add_model.py b/src/cellmap_models/pytorch/cellpose/add_model.py index 7246750..d63d68c 100644 --- a/src/cellmap_models/pytorch/cellpose/add_model.py +++ b/src/cellmap_models/pytorch/cellpose/add_model.py @@ -1,27 +1,20 @@ -from . import models_dict -from cellpose.io import _add_model +import sys +from typing import Optional +from cellpose.io import add_model as _add_model from cellpose.models import MODEL_DIR -from cellpose.utils import download_url_to_file +from .get_model import get_model -def add_model(model_name: str): +def add_model(model_name: Optional[str] = None): """Add model to cellpose Args: model_name (str): model name """ - # download model to cellpose directory - if model_name not in models_dict: - raise ValueError( - f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." - ) + if model_name is None: + model_name = sys.argv[1] base_path = MODEL_DIR - - if not (base_path / f"{model_name}.pth").exists(): - print(f"Downloading {model_name} from {models_dict[model_name]}") - download_url_to_file( - models_dict[model_name], str(base_path / f"{model_name}.pth") - ) + get_model(model_name, base_path) _add_model(str(base_path / f"{model_name}.pth")) print( f"Added model {model_name}. This will now be available in the cellpose model list." diff --git a/src/cellmap_models/pytorch/cellpose/get_model.py b/src/cellmap_models/pytorch/cellpose/get_model.py new file mode 100644 index 0000000..0122dca --- /dev/null +++ b/src/cellmap_models/pytorch/cellpose/get_model.py @@ -0,0 +1,29 @@ +from pathlib import Path +from cellpose.utils import download_url_to_file + + +def get_model( + model_name: str, + base_path: str = f"{Path(__file__).parent}/models", +): + """Add model to cellpose + + Args: + model_name (str): model name + base_path (str, optional): base path to store Torchscript model. Defaults to "./models". + """ + from . import models_dict + + # download model to cellpose directory + if model_name not in models_dict: + raise ValueError( + f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." + ) + + if not (base_path / f"{model_name}.pth").exists(): + print(f"Downloading {model_name} from {models_dict[model_name]}") + download_url_to_file( + models_dict[model_name], str(base_path / f"{model_name}.pth") + ) + print("Downloaded model {model_name} to {base_path}.") + return diff --git a/src/cellmap_models/pytorch/cellpose/load_model.py b/src/cellmap_models/pytorch/cellpose/load_model.py index 9b11bc6..e00c6ba 100644 --- a/src/cellmap_models/pytorch/cellpose/load_model.py +++ b/src/cellmap_models/pytorch/cellpose/load_model.py @@ -1,7 +1,6 @@ from pathlib import Path -from . import models_dict -from cellmap_models.utils import download_url_to_file import torch +from .get_model import get_model def load_model( @@ -19,15 +18,8 @@ def load_model( Returns: model: model """ - if model_name not in models_dict: - raise ValueError( - f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." - ) - if not (base_path / f"{model_name}.pth").exists(): - print(f"Downloading {model_name} from {models_dict[model_name]}") - download_url_to_file( - models_dict[model_name], str(base_path / f"{model_name}.pth") - ) + + get_model(model_name, base_path) if device == "cuda" and not torch.cuda.is_available(): device = "cpu" print("CUDA not available. Using CPU.")