diff --git a/README.md b/README.md index 36c320174..6f5b52f20 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,44 @@ model.train() await model.train(run_async=True) ``` +Most models work by passing data paths in the data config. For training or predicting on datasets that are already in memory, you can pass the data directly to the model. Note that this use case is primarily for programmatic use (e.g. in a workflow or a jupyter notebook), not through the normal CLI. An experiment showing a possible config setup for this use case is demonstrated with the [im2im/segmentation_array](configs/experiment/im2im/segmentation_array.yaml) experiment. For training, data must be passed as a dictionary with keys "train" and "val" containing lists of dictionaries with keys corresponding to the data config. + +```python +from cyto_dl.api import CytoDLModel +import numpy as np + +model = CytoDLModel() +model.load_default_experiment("segmentation_array", output_dir="./output") +model.print_config() + +# create CZYX dummy data +data = { + "train": [{"raw": np.random.randn(1, 40, 256, 256), "seg": np.ones((1, 40, 256, 256))}], + "val": [{"raw": np.random.randn(1, 40, 256, 256), "seg": np.ones((1, 40, 256, 256))}], +} +model.train(data=data) +``` + +For predicting, data must be passed as a list of numpy arrays. The resulting predictions will be processed in a dictionary with one key for each task head in the model config and corresponding values in BC(Z)YX order. + +```python +from cyto_dl.api import CytoDLModel +import numpy as np +from cyto_dl.utils import extract_array_predictions + +model = CytoDLModel() +model.load_default_experiment( + "segmentation_array", output_dir="./output", overrides=["data=im2im/numpy_dataloader_predict"] +) +model.print_config() + +# create CZYX dummy data +data = [np.random.rand(1, 32, 64, 64), np.random.rand(1, 32, 64, 64)] + +_, _, output = model.predict(data=data) +preds = extract_array_predictions(output) +``` + Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/) ```bash diff --git a/configs/data/im2im/numpy_dataloader_predict.yaml b/configs/data/im2im/numpy_dataloader_predict.yaml new file mode 100644 index 000000000..084c490d4 --- /dev/null +++ b/configs/data/im2im/numpy_dataloader_predict.yaml @@ -0,0 +1,16 @@ +_target_: cyto_dl.datamodules.array.make_array_dataloader +data: +num_workers: 1 +batch_size: 1 +source_key: ${source_col} +transforms: + - _target_: monai.transforms.ToTensord + keys: + - ${source_col} + - _target_: cyto_dl.image.transforms.clip.Clipd + keys: + - ${source_col} + - _target_: monai.transforms.NormalizeIntensityd + channel_wise: true + keys: + - ${source_col} diff --git a/configs/data/im2im/numpy_dataloader_train.yaml b/configs/data/im2im/numpy_dataloader_train.yaml new file mode 100644 index 000000000..074474a2a --- /dev/null +++ b/configs/data/im2im/numpy_dataloader_train.yaml @@ -0,0 +1,79 @@ +_aux: + patch_shape: + _scales_dict: + - - ${target_col} + - [1] + - - ${source_col} + - [1] + +train_dataloaders: + _target_: cyto_dl.datamodules.array.make_array_dataloader + data: + num_workers: 0 + batch_size: 1 + source_key: ${source_col} + transforms: + - _target_: monai.transforms.ToTensord + keys: + - ${source_col} + - ${target_col} + - _target_: cyto_dl.image.transforms.clip.Clipd + keys: ${source_col} + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: true + - _target_: monai.transforms.ThresholdIntensityd + keys: ${target_col} + threshold: 0.1 + above: False + cval: 1 + - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + keys: + - ${source_col} + - ${target_col} + patch_shape: ${data._aux.patch_shape} + patch_per_image: 1 + scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + - _target_: monai.transforms.RandHistogramShiftd + prob: 0.1 + keys: ${source_col} + num_control_points: [90, 500] + + - _target_: monai.transforms.RandStdShiftIntensityd + prob: 0.1 + keys: ${source_col} + factors: 0.1 + + - _target_: monai.transforms.RandAdjustContrastd + prob: 0.1 + keys: ${source_col} + gamma: [0.9, 1.5] + +val_dataloaders: + _target_: cyto_dl.datamodules.array.make_array_dataloader + data: + num_workers: 0 + batch_size: 1 + source_key: ${source_col} + transforms: + - _target_: monai.transforms.ToTensord + keys: + - ${source_col} + - ${target_col} + - _target_: cyto_dl.image.transforms.clip.Clipd + keys: ${source_col} + - _target_: monai.transforms.NormalizeIntensityd + keys: ${source_col} + channel_wise: true + - _target_: monai.transforms.ThresholdIntensityd + keys: ${target_col} + threshold: 0.1 + above: False + cval: 1 + - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + keys: + - ${source_col} + - ${target_col} + patch_shape: ${data._aux.patch_shape} + patch_per_image: 1 + scales_dict: ${kv_to_dict:${data._aux._scales_dict}} diff --git a/configs/experiment/im2im/segmentation_array.yaml b/configs/experiment/im2im/segmentation_array.yaml new file mode 100644 index 000000000..d5564879a --- /dev/null +++ b/configs/experiment/im2im/segmentation_array.yaml @@ -0,0 +1,40 @@ +# @package _global_ +# to execute this experiment run: +# python train.py experiment=example +defaults: + - override /data: im2im/numpy_dataloader_train.yaml + - override /model: im2im/segmentation.yaml + - override /callbacks: default.yaml + - override /trainer: gpu.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["dev"] +seed: 12345 + +experiment_name: YOUR_EXP_NAME +run_name: YOUR_RUN_NAME +source_col: raw +target_col: seg +spatial_dims: 3 +raw_im_channels: 1 + +trainer: + max_epochs: 100 + +data: + _aux: + # 2D + # patch_shape: [64, 64] + # 3D + patch_shape: [16, 32, 32] + +callbacks: + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/model/im2im/gan.yaml b/configs/model/im2im/gan.yaml index 4b6afa7f8..0bdd2fccc 100644 --- a/configs/model/im2im/gan.yaml +++ b/configs/model/im2im/gan.yaml @@ -54,7 +54,8 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.25 + overlap: 0 + progress: True mode: "gaussian" _aux: diff --git a/configs/model/im2im/gan_superres.yaml b/configs/model/im2im/gan_superres.yaml index a7d454e58..463e4a370 100644 --- a/configs/model/im2im/gan_superres.yaml +++ b/configs/model/im2im/gan_superres.yaml @@ -54,7 +54,8 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.25 + overlap: 0 + progress: True mode: "gaussian" _aux: diff --git a/configs/model/im2im/instance_seg.yaml b/configs/model/im2im/instance_seg.yaml index 48ad27a2e..751c9ff3e 100644 --- a/configs/model/im2im/instance_seg.yaml +++ b/configs/model/im2im/instance_seg.yaml @@ -34,7 +34,8 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.25 + overlap: 0 + progress: True mode: "gaussian" _aux: diff --git a/configs/model/im2im/labelfree.yaml b/configs/model/im2im/labelfree.yaml index 92655456f..3852678e2 100644 --- a/configs/model/im2im/labelfree.yaml +++ b/configs/model/im2im/labelfree.yaml @@ -35,7 +35,8 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.25 + overlap: 0 + progress: True mode: "gaussian" _aux: diff --git a/configs/model/im2im/mae.yaml b/configs/model/im2im/mae.yaml index ab1fcac39..615f8c3c4 100644 --- a/configs/model/im2im/mae.yaml +++ b/configs/model/im2im/mae.yaml @@ -36,7 +36,8 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.25 + overlap: 0 + progress: True mode: "gaussian" _aux: diff --git a/configs/model/im2im/segmentation.yaml b/configs/model/im2im/segmentation.yaml index a3172567e..3a62021d7 100644 --- a/configs/model/im2im/segmentation.yaml +++ b/configs/model/im2im/segmentation.yaml @@ -35,8 +35,9 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.25 + overlap: 0 mode: "gaussian" + progress: True _aux: _tasks: diff --git a/configs/model/im2im/segmentation_superres.yaml b/configs/model/im2im/segmentation_superres.yaml index 6bdb3414f..24f697594 100644 --- a/configs/model/im2im/segmentation_superres.yaml +++ b/configs/model/im2im/segmentation_superres.yaml @@ -35,7 +35,8 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.25 + overlap: 0 + progress: True mode: "gaussian" _aux: diff --git a/configs/model/im2im/vit_segmentation_decoder.yaml b/configs/model/im2im/vit_segmentation_decoder.yaml index 62ce7e064..6c1ad137b 100644 --- a/configs/model/im2im/vit_segmentation_decoder.yaml +++ b/configs/model/im2im/vit_segmentation_decoder.yaml @@ -37,7 +37,8 @@ lr_scheduler: inference_args: sw_batch_size: 1 roi_size: ${data._aux.patch_shape} - overlap: 0.25 + overlap: 0 + progress: True mode: "gaussian" _aux: diff --git a/cyto_dl/api/data.py b/cyto_dl/api/data.py index 24e9648fa..5f72d1ce6 100644 --- a/cyto_dl/api/data.py +++ b/cyto_dl/api/data.py @@ -11,6 +11,7 @@ class ExperimentType(Enum): LABEL_FREE = "labelfree" SEGMENTATION_PLUGIN = "segmentation_plugin" SEGMENTATION = "segmentation" + SEGMENTATION_ARRAY = "segmentation_array" class HardwareType(Enum): diff --git a/cyto_dl/api/model.py b/cyto_dl/api/model.py index 21e7d9a1c..294b43cf8 100644 --- a/cyto_dl/api/model.py +++ b/cyto_dl/api/model.py @@ -100,16 +100,16 @@ async def _train_async(self): async def _predict_async(self): return evaluate(self.cfg) - def train(self, run_async=False): + def train(self, run_async=False, data=None): if self.cfg is None: raise ValueError("Configuration must be loaded before training!") if run_async: return self._train_async() - return train_model(self.cfg) + return train_model(self.cfg, data) - def predict(self, run_async=False): + def predict(self, run_async=False, data=None): if self.cfg is None: raise ValueError("Configuration must be loaded before predicting!") if run_async: return self._predict_async() - return evaluate(self.cfg) + return evaluate(self.cfg, data) diff --git a/cyto_dl/datamodules/array.py b/cyto_dl/datamodules/array.py new file mode 100644 index 000000000..3de4f9b52 --- /dev/null +++ b/cyto_dl/datamodules/array.py @@ -0,0 +1,48 @@ +from typing import Callable, Dict, List, Sequence, Union + +import numpy as np +from monai.data import DataLoader, Dataset +from monai.transforms import Compose +from omegaconf import ListConfig, OmegaConf + + +def make_array_dataloader( + data: Union[np.ndarray, List[np.ndarray], List[Dict[str, np.ndarray]]], + transforms: Union[Sequence[Callable], Callable], + source_key: str = "input", + **dataloader_kwargs, +): + """Create a dataloader from a an array dataset. + + Parameters + ---------- + data: Union[np.ndarray, List[np.ndarray], List[Dict[str, np.ndarray]], + If a numpy array (prediction only), the dataloader will be created with a single source_key. + If a list each element must be a numpy array (for prediction) or a dictionary containing numpy array values (for training). + + transforms: Union[Sequence[Callable], Callable], + Transforms to apply to each sample + + dataloader_kwargs: + Additional keyword arguments are passed to the + torch.utils.data.DataLoader class when instantiating it (aside from + `shuffle` which is only used for the train dataloader). + Among these args are `num_workers`, `batch_size`, `shuffle`, etc. + See the PyTorch docs for more info on these args: + https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader + """ + if isinstance(transforms, (list, tuple, ListConfig)): + transforms = Compose(transforms) + data = OmegaConf.to_object(data) + if isinstance(data, (list, tuple, ListConfig)): + data = [{source_key: d} if isinstance(d, np.ndarray) else d for d in data] + elif isinstance(data, np.ndarray): + data = [{source_key: data}] + else: + raise ValueError( + f"Invalid data type: {type(data)}. Data must be a numpy array or list of numpy arrays." + ) + + dataset = Dataset(data, transform=transforms) + + return DataLoader(dataset, **dataloader_kwargs) diff --git a/cyto_dl/eval.py b/cyto_dl/eval.py index 1fe1e92c2..7961a151b 100644 --- a/cyto_dl/eval.py +++ b/cyto_dl/eval.py @@ -6,8 +6,9 @@ import hydra from lightning import Callback, LightningDataModule, LightningModule, Trainer from lightning.pytorch.loggers import Logger +from monai.data import DataLoader as MonaiDataLoader from omegaconf import DictConfig, ListConfig, OmegaConf -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader as TorchDataLoader from cyto_dl import utils @@ -19,7 +20,7 @@ @utils.task_wrapper -def evaluate(cfg: DictConfig) -> Tuple[dict, dict, dict]: +def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]: """Evaluates given checkpoint on a datamodule testset. This method is wrapped in optional @task_wrapper decorator which applies extra utilities @@ -40,9 +41,8 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict, dict]: # remove aux section after resolving and before instantiating utils.remove_aux_key(cfg) - - data = hydra.utils.instantiate(cfg.data) - if not isinstance(data, (LightningDataModule, DataLoader)): + data = utils.create_dataloader(cfg.data, data) + if not isinstance(data, (LightningDataModule, TorchDataLoader, MonaiDataLoader)): if isinstance(data, MutableMapping) and not data.dataloaders: raise ValueError( "If the `data` config for eval/prediction is a dict it must have a " diff --git a/cyto_dl/train.py b/cyto_dl/train.py index e00eaaa24..0c3e03c50 100644 --- a/cyto_dl/train.py +++ b/cyto_dl/train.py @@ -23,7 +23,7 @@ @utils.task_wrapper -def train(cfg: DictConfig) -> Tuple[dict, dict]: +def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]: """Trains the model. Can additionally evaluate on a testset, using best weights obtained during training. @@ -57,7 +57,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: utils.remove_aux_key(cfg) log.info(f"Instantiating data <{cfg.data.get('_target_', cfg.data)}>") - data = hydra.utils.instantiate(cfg.data) + data = utils.create_dataloader(cfg.data, data) if not isinstance(data, LightningDataModule): if not isinstance(data, MutableMapping) or "train_dataloaders" not in data: raise ValueError( diff --git a/cyto_dl/utils/__init__.py b/cyto_dl/utils/__init__.py index 7a9cdfdb9..ac3acaeab 100644 --- a/cyto_dl/utils/__init__.py +++ b/cyto_dl/utils/__init__.py @@ -1,3 +1,4 @@ +from .array import create_dataloader, extract_array_predictions from .config import kv_to_dict, remove_aux_key from .pylogger import get_pylogger from .rich_utils import enforce_tags, print_config_tree diff --git a/cyto_dl/utils/array.py b/cyto_dl/utils/array.py new file mode 100644 index 000000000..5312aa618 --- /dev/null +++ b/cyto_dl/utils/array.py @@ -0,0 +1,42 @@ +import hydra +import numpy as np +from omegaconf import OmegaConf + + +def create_dataloader(data_cfg, data=None): + """Create a dataloader from a data config and optional data.""" + data_cfg = OmegaConf.to_object(data_cfg) + if data is not None: + # inference, using make_array_dataloader + if "data" in data_cfg: + data_cfg["data"] = data + # training, has train_dataloaders/val_dataloaders + for split in ("train", "val", "test"): + if f"{split}_dataloaders" in data_cfg: + data_cfg[f"{split}_dataloaders"]["data"] = data[split] + + # Instantiate the dataloader with the dataset + dataloader = hydra.utils.instantiate(data_cfg) + + return dataloader + + +def extract_array_predictions(output, task_heads=None): + """Converts output from model.predict() to a dictionary of numpy arrays per head.""" + predictions = {} + for batch_pred in output: + # ignore io_map + _, batch_pred = batch_pred + # if no task_heads are provided, use all + if task_heads is None: + task_heads = list(batch_pred.keys()) + # combine all predictions per-head + for head in task_heads: + if head not in predictions: + predictions[head] = [] + predictions[head] += batch_pred[head]["pred"] + # stack head predictions into numpy array + for head, pred in predictions.items(): + predictions[head] = np.stack(pred) + + return predictions diff --git a/cyto_dl/utils/template_utils.py b/cyto_dl/utils/template_utils.py index 9b4664436..05291d218 100644 --- a/cyto_dl/utils/template_utils.py +++ b/cyto_dl/utils/template_utils.py @@ -4,13 +4,13 @@ import warnings from importlib.util import find_spec from pathlib import Path -from typing import Any, Callable, Dict, List +from typing import Any, Callable, List import hydra from lightning import Callback from lightning.pytorch.loggers import Logger from lightning.pytorch.utilities import rank_zero_only -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from cyto_dl.loggers import MLFlowLogger @@ -43,14 +43,14 @@ def task_wrapper(task_func: Callable) -> Callable: - Logging the output dir """ - def wrap(cfg: DictConfig): + def wrap(cfg: DictConfig, data: Any = None): # apply extra utilities extras(cfg) # execute the task try: start_time = time.time() - out = task_func(cfg=cfg) + out = task_func(cfg=cfg, data=data) except Exception as ex: log.exception("") # save exception to `.log` file raise ex diff --git a/tests/test_array_models.py b/tests/test_array_models.py new file mode 100644 index 000000000..6d3695995 --- /dev/null +++ b/tests/test_array_models.py @@ -0,0 +1,63 @@ +from pathlib import Path + +import numpy as np +import pytest + +from cyto_dl.api import CytoDLModel +from cyto_dl.utils import extract_array_predictions + + +@pytest.mark.skip +def test_array_train(tmp_path): + model = CytoDLModel() + + overrides = { + "trainer.max_epochs": 1, + "logger": None, + "trainer.accelerator": "cpu", + "trainer.devices": 1, + } + + model.load_default_experiment(experiment_type="segmentation_array", output_dir=tmp_path) + model.override_config(overrides) + + data = { + "train": [{"raw": np.random.randn(1, 40, 256, 256), "seg": np.ones((1, 40, 256, 256))}], + "val": [{"raw": np.random.randn(1, 40, 256, 256), "seg": np.ones((1, 40, 256, 256))}], + } + model.train(data=data) + + ckpt_dir = Path(model.cfg.callbacks.model_checkpoint.dirpath) + assert "last.ckpt" in [fn.name for fn in ckpt_dir.iterdir()] + return ckpt_dir / "last.ckpt" + + +@pytest.mark.slow +def test_array_train_predict(tmp_path): + ckpt_path = test_array_train(tmp_path) + + model = CytoDLModel() + + overrides = { + "logger": None, + "trainer.accelerator": "cpu", + "trainer.devices": 1, + "ckpt_path": ckpt_path, + } + + model.load_default_experiment( + experiment_type="segmentation_array", + output_dir=tmp_path, + train=False, + overrides=["data=im2im/numpy_dataloader_predict"], + ) + model.override_config(overrides) + model.print_config() + + data = [np.random.rand(1, 32, 64, 64), np.random.rand(1, 32, 64, 64)] + _, _, output = model.predict(data=data) + preds = extract_array_predictions(output) + + for head in model.cfg.model.task_heads.keys(): + assert preds[head].shape[0] == len(data) + assert preds[head].shape[1:] == data[0].shape