Skip to content

Commit

Permalink
feat: ✨ Add Cellpose model loading utilities.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 7, 2024
1 parent 4f59a8d commit a0af74f
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 3 deletions.
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,17 @@ dev = [
'pdoc',
'pre-commit'
]
pretrained = [
'cellpose[gui]'
]

[project.urls]
homepage = "https://github.com/janelia-cellmap/cellmap-models"
repository = "https://github.com/janelia-cellmap/cellmap-models"

[tool.mypy]
exclude = ['setup*']
ignore_missing_imports = true
ignore_missing_imports = true

[project.scripts]
cellmap.add_cellpose = "cellmap_models.cellpose:add_model"
2 changes: 1 addition & 1 deletion src/cellmap_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""

from .utils import download_url_to_file
from .pytorch import cosem
from .pytorch import cosem, cellpose
1 change: 1 addition & 0 deletions src/cellmap_models/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import cosem
from . import cellpose
36 changes: 36 additions & 0 deletions src/cellmap_models/pytorch/cellpose/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<!-- FILEPATH: /Users/rhoadesj/Repos/cellmap-models/src/cellmap_models/pytorch/cellpose/README.md -->
<h1 style="height: 2em;">Finetuned Cellpose Models <img src="https://www.cellpose.org/static/images/cellpose_transparent.png" alt="cellpose logo"></h1>

This directory contains finetuned scripts for downloading Cellpose models, particularly for use with the `cellpose` package. The models are trained on a variety of cell types from CellMap FIBSEM images, and can be used for segmentation of new data.

## Models

...

## Usage

Once you have chosen a model based on the descriptions above, you can download its weights from the `cellmap-models` repository and use them as described below:

If you would like to load a model for your own use, you can do the following:

```python
from cellmap_models.cellpose import load_model
model = load_model('<model_name>')
```

__If you would like to download and use a Cellpose model with the `cellpose` package or its GUI, do so by following the instructions below.__

First install the `cellpose` package:

```bash
conda activate cellmap
pip install cellpose[gui]
```

Then you can also download model weights from the `cellmap-models` repository and add them to your local `cellpose` model directory. For example, you can run the following commands:

```bash
cellmap.add_cellpose <model_name>
```

where `<model_name>` is the name of the model you would like to download, based on the descriptions above. For example, to download the `...
2 changes: 2 additions & 0 deletions src/cellmap_models/pytorch/cellpose/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .add_model import add_model
from .load_model import load_model
29 changes: 29 additions & 0 deletions src/cellmap_models/pytorch/cellpose/add_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from . import models_dict
from cellpose.io import _add_model
from cellpose.models import MODEL_DIR
from cellpose.utils import download_url_to_file


def add_model(model_name: str):
"""Add model to cellpose
Args:
model_name (str): model name
"""
# download model to cellpose directory
if model_name not in models_dict:
raise ValueError(
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
)
base_path = MODEL_DIR

if not (base_path / f"{model_name}.pth").exists():
print(f"Downloading {model_name} from {models_dict[model_name]}")
download_url_to_file(
models_dict[model_name], str(base_path / f"{model_name}.pth")
)
_add_model(str(base_path / f"{model_name}.pth"))
print(
f"Added model {model_name}. This will now be available in the cellpose model list."
)
return
36 changes: 36 additions & 0 deletions src/cellmap_models/pytorch/cellpose/load_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from pathlib import Path
from . import models_dict
from cellmap_models.utils import download_url_to_file
import torch


def load_model(
model_name: str,
base_path: str = f"{Path(__file__).parent}/models",
device: str = "cuda",
):
"""Load model
Args:
model_name (str): model name
base_path (str, optional): base path to store Torchscript model. Defaults to "./models".
device (str, optional): device. Defaults to "cuda".
Returns:
model: model
"""
if model_name not in models_dict:
raise ValueError(
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
)
if not (base_path / f"{model_name}.pth").exists():
print(f"Downloading {model_name} from {models_dict[model_name]}")
download_url_to_file(
models_dict[model_name], str(base_path / f"{model_name}.pth")
)
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
print("CUDA not available. Using CPU.")
model = torch.jit.load(str(base_path / f"{model_name}.pth"), device)
model.eval()
return model
6 changes: 6 additions & 0 deletions src/cellmap_models/pytorch/cosem/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def get_param_dict(model_params):


def load_model(checkpoint_path):
"""
Load a model from a checkpoint file.
Args:
checkpoint_path (str): Path to the checkpoint file.
"""
if not Path(checkpoint_path).exists():
checkpoint_path = Path(Path(__file__).parent / checkpoint_path)
model_params = SourceFileLoader(
Expand Down
3 changes: 2 additions & 1 deletion src/cellmap_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@


def download_url_to_file(url, dst, progress=True):
r"""Download object at the given URL to a local path.
# Originally from CellPose
"""Download object at the given URL to a local path.
Thanks to torch, slightly modified
Args:
url (string): URL of the object to download
Expand Down

0 comments on commit a0af74f

Please sign in to comment.