Commit da7119c
remove unused arguments from custom subcommands
ziw-liu committed Oct 11, 2024
1 parent 399c2bd commit da7119c
Showing 4 changed files with 72 additions and 54 deletions.
docs/usage.md (6 additions, 2 deletions)
@@ -15,10 +15,14 @@ and save them to Zarr metadata.
 viscy preprocess -c config.yaml
 ```
 
-Or to preprocess all channels with the default sampling rate and 1 worker:
+An example of the config file can be found [here](../examples/configs/preprocess_example.yml).
+
+There are only a few arguments for this command,
+so it may be desirable to run without having to edit a config file.
+To preprocess all channels with the default sampling rate and 8 workers:
 
 ```sh
-viscy preprocess --data_path /path/to/data.zarr
+viscy preprocess --data_path=/path/to/data.zarr --num_workers=8
 ```
 
 ## Training
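For reference, the CLI call above maps onto the trainer API modified in this commit (see `viscy/trainer.py` below). A minimal sketch of the equivalent direct call, with a placeholder dataset path:

```python
from pathlib import Path

from viscy.trainer import VisCyTrainer

# Equivalent of `viscy preprocess --data_path=... --num_workers=8`,
# using the preprocess() signature introduced in this commit.
trainer = VisCyTrainer()
trainer.preprocess(
    data_path=Path("/path/to/data.zarr"),
    channel_names=-1,  # -1 selects all channels
    num_workers=8,
)
```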
examples/configs/export_example.yml (21 additions, 13 deletions)
@@ -1,21 +1,29 @@
-# lightning.pytorch==2.0.4
-seed_everything: true
+# Export the FCMAE-pretrained VSCyto2D weights to ONNX format
+seed_everything: 42
 trainer:
   accelerator: auto
   strategy: auto
   devices: auto
   num_nodes: 1
   precision: 32-true
   callbacks: []
 model:
-  architecture: null
-  model_config: {}
-  loss_function: null
-  lr: 0.001
-  schedule: Constant
-  log_num_samples: 8
-  test_cellpose_model_path: null
-  test_cellpose_diameter: null
-  test_evaluate_cellpose: false
-export_path: null
-ckpt_path: null
+  class_path: viscy.translation.engine.VSUNet
+  init_args:
+    architecture: fcmae
+    model_config:
+      in_channels: 1
+      out_channels: 2
+      encoder_blocks: [3, 3, 9, 3]
+      dims: [96, 192, 384, 768]
+      decoder_conv_blocks: 2
+      stem_kernel_size: [1, 2, 2]
+      in_stack_depth: 1
+      pretraining: False
+# TODO: output path for the exported model
+export_path: /hpc/mydata/ziwen.liu/ckpt.onnx
+# TODO: path to the checkpoint file
+# Download from:
+# https://public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto2D/VSCyto2D/epoch=399-step=23200.ckpt
+ckpt_path: /hpc/websites/public.czbiohub.org/comp.micro/viscy/VS_models/VSCyto2D/VSCyto2D/epoch=399-step=23200.ckpt
+format: onnx
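A quick way to sanity-check the exported file is with `onnxruntime` (not a VisCy dependency; the five-dimensional input shape below is an assumption based on `in_channels: 1` and `in_stack_depth: 1` in the config above):

```python
import numpy as np
import onnxruntime as ort

# Load the exported model; the filename mirrors export_path above.
session = ort.InferenceSession("ckpt.onnx")
input_name = session.get_inputs()[0].name
# Assumed (batch, channel, depth, height, width) layout with a 256x256 tile.
dummy = np.zeros((1, 1, 1, 256, 256), dtype=np.float32)
outputs = session.run(None, {input_name: dummy})
print([output.shape for output in outputs])
```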
viscy/cli.py (6 additions, 5 deletions)
@@ -18,7 +18,7 @@ class VisCyCLI(LightningCLI):
     @staticmethod
     def subcommands() -> dict[str, set[str]]:
         subcommands = LightningCLI.subcommands()
-        subcommand_base_args = {"model", "dataloaders", "datamodule"}
+        subcommand_base_args = {"model"}
         subcommands["preprocess"] = subcommand_base_args
         subcommands["export"] = subcommand_base_args
         return subcommands
@@ -52,14 +52,15 @@ def main() -> None:
     Set default random seed to 42.
     """
     _setup_environment()
-    subclass_mode = bool("preprocess" not in sys.argv)
+    require_model = "preprocess" not in sys.argv
+    require_data = {"preprocess", "export"}.isdisjoint(sys.argv)
     _ = VisCyCLI(
         model_class=LightningModule,
-        datamodule_class=LightningDataModule if subclass_mode else None,
+        datamodule_class=LightningDataModule if require_data else None,
         trainer_class=VisCyTrainer,
         seed_everything_default=42,
-        subclass_mode_model=subclass_mode,
-        subclass_mode_data=subclass_mode,
+        subclass_mode_model=require_model,
+        subclass_mode_data=require_data,
         parser_kwargs={
             "description": "Computer vision models for single-cell phenotyping."
         },
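The flag logic above can be checked in isolation with the standard library alone; the argv values here are hypothetical:

```python
# require_data is True only when argv mentions neither custom subcommand,
# i.e. only fit/validate/test/predict need a datamodule.
for argv in (["viscy", "fit"], ["viscy", "preprocess"], ["viscy", "export"]):
    require_model = "preprocess" not in argv
    require_data = {"preprocess", "export"}.isdisjoint(argv)
    print(argv[1], require_model, require_data)
# fit True True / preprocess False False / export True False
```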
viscy/trainer.py (39 additions, 34 deletions)
@@ -1,40 +1,45 @@
 import logging
 from pathlib import Path
-from typing import Literal, Sequence, Union
+from typing import Literal
 
 import torch
 from iohub import open_ome_zarr
-from lightning.pytorch import LightningDataModule, LightningModule, Trainer
+from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized
 from torch.onnx import OperatorExportTypes
 
 from viscy.utils.meta_utils import generate_normalization_metadata
 
+_logger = logging.getLogger("lightning.pytorch")
+
 
 class VisCyTrainer(Trainer):
     def preprocess(
         self,
         data_path: Path,
-        channel_names: Union[list[str], Literal[-1]] = -1,
+        channel_names: list[str] | Literal[-1] = -1,
         num_workers: int = 1,
         block_size: int = 32,
-        model: LightningModule = None,
-        datamodule: LightningDataModule = None,
-        dataloaders: Sequence = None,
+        model: LightningModule | None = None,
     ):
-        """Compute dataset statistics before training or testing for normalization.
-        :param Path data_path: Path to the HCS OME-Zarr dataset
-        :param Union[list[str], Literal[-1]] channel_names: channel names,
-            defaults to -1 (all channels)
-        :param int num_workers: number of workers, defaults to 1
-        :param int block_size: sampling block size, defaults to 32
-        :param LightningModule model: place holder for model, ignored
-        :param LightningDataModule datamodule: place holder for datamodule, ignored
-        :param Sequence dataloaders: place holder for dataloaders, ignored
+        """
+        Compute dataset statistics before training or testing for normalization.
+
+        Parameters
+        ----------
+        data_path : Path
+            Path to the HCS OME-Zarr dataset
+        channel_names : list[str] | Literal[-1], optional
+            Channel names to compute statistics for, by default -1
+        num_workers : int, optional
+            Number of CPU workers, by default 1
+        block_size : int, optional
+            Block size to subsample images, by default 32
+        model : LightningModule, optional
+            Ignored placeholder, by default None
         """
-        if model or dataloaders or datamodule:
-            logging.debug("Ignoring model and data configs during preprocessing.")
+        if model is not None:
+            _logger.warning("Ignoring model configuration during preprocessing.")
         with open_ome_zarr(data_path, layout="hcs", mode="r") as dataset:
             channel_indices = (
                 [dataset.channel_names.index(c) for c in channel_names]
@@ -51,29 +56,29 @@ def preprocess(
     def export(
         self,
         model: LightningModule,
-        export_path: str,
-        ckpt_path: str,
-        format="onnx",
-        datamodule: LightningDataModule = None,
-        dataloaders: Sequence = None,
+        export_path: Path,
+        ckpt_path: Path,
+        format: str = "onnx",
     ):
-        """Export the model for deployment (currently only ONNX is supported).
-        :param LightningModule model: module to export
-        :param str export_path: output file name
-        :param str ckpt_path: model checkpoint
-        :param str format: format (currently only ONNX is supported), defaults to "onnx"
-        :param LightningDataModule datamodule: placeholder for datamodule,
-            defaults to None
-        :param Sequence dataloaders: placeholder for dataloaders, defaults to None
+        """
+        Export the model for deployment (currently only ONNX is supported).
+
+        Parameters
+        ----------
+        model : LightningModule
+            Module to export.
+        export_path : Path
+            Output file name.
+        ckpt_path : Path
+            Model checkpoint path.
+        format : str, optional
+            Format (currently only ONNX is supported), by default "onnx".
         """
-        if dataloaders or datamodule:
-            logging.debug("Ignoring datamodule and dataloaders during export.")
         if not format.lower() == "onnx":
             raise NotImplementedError(f"Export format '{format}'")
         model = _maybe_unwrap_optimized(model)
         self.strategy._lightning_module = model
-        model.load_state_dict(torch.load(ckpt_path)["state_dict"])
+        model.load_state_dict(torch.load(ckpt_path, weights_only=True)["state_dict"])
         model.eval()
         model.to_onnx(
             export_path,
@@ -98,4 +103,4 @@ def export(
                 },
             },
         )
-        logging.info(f"ONNX exported at {export_path}")
+        _logger.info(f"ONNX exported at {export_path}")
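For completeness, a sketch of calling the slimmed-down `export()` directly. The `VSUNet` init arguments follow the example config above but are abridged, and the local paths are placeholders:

```python
from pathlib import Path

from viscy.trainer import VisCyTrainer
from viscy.translation.engine import VSUNet

# Abridged model_config for illustration; use the full init_args from
# examples/configs/export_example.yml to reproduce the VSCyto2D export.
model = VSUNet(
    architecture="fcmae",
    model_config={"in_channels": 1, "out_channels": 2},
)
VisCyTrainer().export(
    model=model,
    export_path=Path("vscyto2d.onnx"),
    ckpt_path=Path("epoch=399-step=23200.ckpt"),
    format="onnx",
)
```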
