Skip to content

Commit

Permalink
Feature/array dataloaders (#410)
Browse files Browse the repository at this point in the history
* add array utils

* add dataloader

* change no overlap, display progress

* add numpy array experiment to enum

* precommit

* add dataloader creation function

* add array configs

* add test

* update docs

* remove prediction saving callback

* remove aux section

* update docstring

* update docs

* precommit

---------

Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
benjijamorris and Benjamin Morris authored Aug 7, 2024
1 parent 8e05a63 commit 7668cac
Show file tree
Hide file tree
Showing 21 changed files with 359 additions and 23 deletions.
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,44 @@ model.train()
await model.train(run_async=True)
```

Most models work by passing data paths in the data config. For training or predicting on datasets that are already in memory, you can pass the data directly to the model. Note that this use case is primarily for programmatic use (e.g. in a workflow or a Jupyter notebook), not through the normal CLI. A possible config setup for this use case is demonstrated by the [im2im/segmentation_array](configs/experiment/im2im/segmentation_array.yaml) experiment. For training, data must be passed as a dictionary with keys "train" and "val", each containing a list of dictionaries whose keys correspond to the data config.

```python
from cyto_dl.api import CytoDLModel
import numpy as np

model = CytoDLModel()
model.load_default_experiment("segmentation_array", output_dir="./output")
model.print_config()

# create CZYX dummy data
data = {
"train": [{"raw": np.random.randn(1, 40, 256, 256), "seg": np.ones((1, 40, 256, 256))}],
"val": [{"raw": np.random.randn(1, 40, 256, 256), "seg": np.ones((1, 40, 256, 256))}],
}
model.train(data=data)
```

For predicting, data must be passed as a list of numpy arrays. The resulting predictions will be returned as a dictionary with one key per task head in the model config and corresponding values in BC(Z)YX order.

```python
from cyto_dl.api import CytoDLModel
import numpy as np
from cyto_dl.utils import extract_array_predictions

model = CytoDLModel()
model.load_default_experiment(
"segmentation_array", output_dir="./output", overrides=["data=im2im/numpy_dataloader_predict"]
)
model.print_config()

# create CZYX dummy data
data = [np.random.rand(1, 32, 64, 64), np.random.rand(1, 32, 64, 64)]

_, _, output = model.predict(data=data)
preds = extract_array_predictions(output)
```

Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)

```bash
Expand Down
16 changes: 16 additions & 0 deletions configs/data/im2im/numpy_dataloader_predict.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Dataloader config for predicting directly on in-memory numpy arrays.
# `data` is intentionally left null here: the arrays are supplied
# programmatically (e.g. model.predict(data=[...])) and merged in at runtime.
_target_: cyto_dl.datamodules.array.make_array_dataloader
data:
num_workers: 1
batch_size: 1
# key under which each raw input array is stored in the sample dict
source_key: ${source_col}
transforms:
  - _target_: monai.transforms.ToTensord
    keys:
      - ${source_col}
  # NOTE(review): presumably clips intensity outliers — see cyto_dl.image.transforms.clip.Clipd
  - _target_: cyto_dl.image.transforms.clip.Clipd
    keys:
      - ${source_col}
  # per-channel z-score normalization of the input image
  - _target_: monai.transforms.NormalizeIntensityd
    channel_wise: true
    keys:
      - ${source_col}
79 changes: 79 additions & 0 deletions configs/data/im2im/numpy_dataloader_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Dataloader config for training directly on in-memory numpy arrays.
# The `data` keys are intentionally left null: arrays are supplied
# programmatically via model.train(data={"train": [...], "val": [...]}).
_aux:
  # filled in by the experiment config (e.g. [16, 32, 32])
  patch_shape:
  # per-key scale factors, converted to a dict for RandomMultiScaleCropd
  _scales_dict:
    - - ${target_col}
      - [1]
    - - ${source_col}
      - [1]

train_dataloaders:
  _target_: cyto_dl.datamodules.array.make_array_dataloader
  data:
  num_workers: 0
  batch_size: 1
  # key under which raw input arrays are stored in each sample dict
  source_key: ${source_col}
  transforms:
    - _target_: monai.transforms.ToTensord
      keys:
        - ${source_col}
        - ${target_col}
    # NOTE(review): presumably clips intensity outliers — see cyto_dl.image.transforms.clip.Clipd
    - _target_: cyto_dl.image.transforms.clip.Clipd
      keys: ${source_col}
    # per-channel z-score normalization of the raw image
    - _target_: monai.transforms.NormalizeIntensityd
      keys: ${source_col}
      channel_wise: true
    # binarize the target: with above=False, values at/above the threshold
    # are replaced by cval=1 — TODO confirm against MONAI ThresholdIntensityd docs
    - _target_: monai.transforms.ThresholdIntensityd
      keys: ${target_col}
      threshold: 0.1
      above: False
      cval: 1
    # random patch extraction; target/source scales come from _scales_dict
    - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
      keys:
        - ${source_col}
        - ${target_col}
      patch_shape: ${data._aux.patch_shape}
      patch_per_image: 1
      scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
    # random intensity augmentations below are applied to the raw image only
    - _target_: monai.transforms.RandHistogramShiftd
      prob: 0.1
      keys: ${source_col}
      num_control_points: [90, 500]

    - _target_: monai.transforms.RandStdShiftIntensityd
      prob: 0.1
      keys: ${source_col}
      factors: 0.1

    - _target_: monai.transforms.RandAdjustContrastd
      prob: 0.1
      keys: ${source_col}
      gamma: [0.9, 1.5]

val_dataloaders:
  _target_: cyto_dl.datamodules.array.make_array_dataloader
  data:
  num_workers: 0
  batch_size: 1
  source_key: ${source_col}
  # same preprocessing as training, without the random intensity augmentations
  transforms:
    - _target_: monai.transforms.ToTensord
      keys:
        - ${source_col}
        - ${target_col}
    - _target_: cyto_dl.image.transforms.clip.Clipd
      keys: ${source_col}
    - _target_: monai.transforms.NormalizeIntensityd
      keys: ${source_col}
      channel_wise: true
    - _target_: monai.transforms.ThresholdIntensityd
      keys: ${target_col}
      threshold: 0.1
      above: False
      cval: 1
    - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
      keys:
        - ${source_col}
        - ${target_col}
      patch_shape: ${data._aux.patch_shape}
      patch_per_image: 1
      scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
40 changes: 40 additions & 0 deletions configs/experiment/im2im/segmentation_array.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# @package _global_
# Experiment for training a segmentation model on in-memory numpy arrays.
# To execute this experiment run:
# python train.py experiment=im2im/segmentation_array
defaults:
  - override /data: im2im/numpy_dataloader_train.yaml
  - override /model: im2im/segmentation.yaml
  - override /callbacks: default.yaml
  - override /trainer: gpu.yaml
  - override /logger: csv.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["dev"]
seed: 12345

experiment_name: YOUR_EXP_NAME
run_name: YOUR_RUN_NAME
# keys of the arrays inside the data dictionaries passed to model.train(data=...)
source_col: raw
target_col: seg
spatial_dims: 3
raw_im_channels: 1

trainer:
  max_epochs: 100

data:
  _aux:
    # 2D
    # patch_shape: [64, 64]
    # 3D
    patch_shape: [16, 32, 32]

callbacks:
  # save input/prediction images during train/val/test for visual inspection
  saving:
    _target_: cyto_dl.callbacks.ImageSaver
    save_dir: ${paths.output_dir}
    save_every_n_epochs: ${model.save_images_every_n_epochs}
    stages: ["train", "test", "val"]
    save_input: True
3 changes: 2 additions & 1 deletion configs/model/im2im/gan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: 0
progress: True
mode: "gaussian"

_aux:
Expand Down
3 changes: 2 additions & 1 deletion configs/model/im2im/gan_superres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: 0
progress: True
mode: "gaussian"

_aux:
Expand Down
3 changes: 2 additions & 1 deletion configs/model/im2im/instance_seg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: 0
progress: True
mode: "gaussian"

_aux:
Expand Down
3 changes: 2 additions & 1 deletion configs/model/im2im/labelfree.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: 0
progress: True
mode: "gaussian"

_aux:
Expand Down
3 changes: 2 additions & 1 deletion configs/model/im2im/mae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: 0
progress: True
mode: "gaussian"

_aux:
Expand Down
3 changes: 2 additions & 1 deletion configs/model/im2im/segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: 0
mode: "gaussian"
progress: True

_aux:
_tasks:
Expand Down
3 changes: 2 additions & 1 deletion configs/model/im2im/segmentation_superres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: 0
progress: True
mode: "gaussian"

_aux:
Expand Down
3 changes: 2 additions & 1 deletion configs/model/im2im/vit_segmentation_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ lr_scheduler:
inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0.25
overlap: 0
progress: True
mode: "gaussian"

_aux:
Expand Down
1 change: 1 addition & 0 deletions cyto_dl/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class ExperimentType(Enum):
LABEL_FREE = "labelfree"
SEGMENTATION_PLUGIN = "segmentation_plugin"
SEGMENTATION = "segmentation"
SEGMENTATION_ARRAY = "segmentation_array"


class HardwareType(Enum):
Expand Down
8 changes: 4 additions & 4 deletions cyto_dl/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,16 @@ async def _train_async(self):
async def _predict_async(self):
return evaluate(self.cfg)

def train(self, run_async=False):
def train(self, run_async=False, data=None):
if self.cfg is None:
raise ValueError("Configuration must be loaded before training!")
if run_async:
return self._train_async()
return train_model(self.cfg)
return train_model(self.cfg, data)

def predict(self, run_async=False):
def predict(self, run_async=False, data=None):
if self.cfg is None:
raise ValueError("Configuration must be loaded before predicting!")
if run_async:
return self._predict_async()
return evaluate(self.cfg)
return evaluate(self.cfg, data)
48 changes: 48 additions & 0 deletions cyto_dl/datamodules/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Callable, Dict, List, Sequence, Union

import numpy as np
from monai.data import DataLoader, Dataset
from monai.transforms import Compose
from omegaconf import ListConfig, OmegaConf


def make_array_dataloader(
    data: Union[np.ndarray, List[np.ndarray], List[Dict[str, np.ndarray]]],
    transforms: Union[Sequence[Callable], Callable],
    source_key: str = "input",
    **dataloader_kwargs,
):
    """Create a dataloader from an in-memory array dataset.

    Parameters
    ----------
    data: Union[np.ndarray, List[np.ndarray], List[Dict[str, np.ndarray]]]
        If a numpy array (prediction only), the dataloader will be created with a single source_key.
        If a list, each element must be a numpy array (for prediction) or a dictionary containing
        numpy array values (for training).
    transforms: Union[Sequence[Callable], Callable]
        Transforms to apply to each sample. A sequence is wrapped in a
        ``monai.transforms.Compose``.
    source_key: str
        Key under which bare numpy arrays are stored in each sample dict.
    dataloader_kwargs:
        Additional keyword arguments are passed to the
        torch.utils.data.DataLoader class when instantiating it (aside from
        `shuffle` which is only used for the train dataloader).
        Among these args are `num_workers`, `batch_size`, `shuffle`, etc.
        See the PyTorch docs for more info on these args:
        https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

    Returns
    -------
    monai.data.DataLoader
        Dataloader over the provided samples.

    Raises
    ------
    ValueError
        If ``data`` is neither a numpy array nor a list/tuple of arrays or dicts.
    """
    if isinstance(transforms, (list, tuple, ListConfig)):
        transforms = Compose(transforms)

    # OmegaConf.to_object only accepts OmegaConf nodes; calling it on a plain
    # numpy array or list (the documented programmatic inputs) raises
    # ValueError, so convert only when data actually is a config node.
    if OmegaConf.is_config(data):
        data = OmegaConf.to_object(data)

    if isinstance(data, np.ndarray):
        # single array -> single-sample prediction dataset
        data = [{source_key: data}]
    elif isinstance(data, (list, tuple)):
        # wrap bare arrays; pass through dicts (training samples) unchanged
        data = [{source_key: d} if isinstance(d, np.ndarray) else d for d in data]
    else:
        raise ValueError(
            f"Invalid data type: {type(data)}. Data must be a numpy array or list of numpy arrays."
        )

    dataset = Dataset(data, transform=transforms)

    return DataLoader(dataset, **dataloader_kwargs)
10 changes: 5 additions & 5 deletions cyto_dl/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import hydra
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from monai.data import DataLoader as MonaiDataLoader
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader

from cyto_dl import utils

Expand All @@ -19,7 +20,7 @@


@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict, dict]:
def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]:
"""Evaluates given checkpoint on a datamodule testset.
This method is wrapped in optional @task_wrapper decorator which applies extra utilities
Expand All @@ -40,9 +41,8 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict, dict]:

# remove aux section after resolving and before instantiating
utils.remove_aux_key(cfg)

data = hydra.utils.instantiate(cfg.data)
if not isinstance(data, (LightningDataModule, DataLoader)):
data = utils.create_dataloader(cfg.data, data)
if not isinstance(data, (LightningDataModule, TorchDataLoader, MonaiDataLoader)):
if isinstance(data, MutableMapping) and not data.dataloaders:
raise ValueError(
"If the `data` config for eval/prediction is a dict it must have a "
Expand Down
4 changes: 2 additions & 2 deletions cyto_dl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
training.
Expand Down Expand Up @@ -57,7 +57,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
utils.remove_aux_key(cfg)

log.info(f"Instantiating data <{cfg.data.get('_target_', cfg.data)}>")
data = hydra.utils.instantiate(cfg.data)
data = utils.create_dataloader(cfg.data, data)
if not isinstance(data, LightningDataModule):
if not isinstance(data, MutableMapping) or "train_dataloaders" not in data:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions cyto_dl/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .array import create_dataloader, extract_array_predictions
from .config import kv_to_dict, remove_aux_key
from .pylogger import get_pylogger
from .rich_utils import enforce_tags, print_config_tree
Expand Down
Loading

0 comments on commit 7668cac

Please sign in to comment.