SmartCache Bugfix (#352)
* update train valsplitting

* update cache during training

* add dtype arg

* remove smartcache print

---------

Co-authored-by: Benjamin Morris <[email protected]>
benjijamorris and Benjamin Morris authored Mar 13, 2024
1 parent 07e0b9a commit 61d2015
Showing 3 changed files with 92 additions and 33 deletions.
106 changes: 74 additions & 32 deletions cyto_dl/datamodules/smartcache.py
@@ -4,12 +4,11 @@
import dask
import numpy as np
import pandas as pd
import tqdm
from aicsimageio import AICSImage
from dask.diagnostics import ProgressBar
from lightning import LightningDataModule
from monai.data import DataLoader
from monai.data.dataset import Dataset, SmartCacheDataset
from monai.data.dataset import CacheDataset, Dataset, SmartCacheDataset
from monai.transforms import Compose
from sklearn.model_selection import train_test_split

@@ -19,17 +18,18 @@ class SmartcacheDatamodule(LightningDataModule):

def __init__(
self,
csv_path: Union[Path, str],
csv_path: Optional[Union[Path, str]] = None,
transforms: Compose = None,
img_data: Optional[Union[Path, str]] = None,
n_val: int = 20,
pct_val: float = 0.1,
img_path_column: str = "raw",
channel_column: str = "ch",
spatial_dims: int = 3,
neighboring_timepoints: bool = False,
num_neighbors: int = 0,
num_workers: int = 4,
cache_rate: float = 0.5,
replace_rate: float = 0.1,
**kwargs,
):
"""
@@ -51,33 +51,49 @@ def __init__(
column in csv_path that contains the channel to use
spatial_dims: int
number of spatial dimensions in the image
neighboring_timepoints: bool
whether to return T and T+1 as a 2-channel image; useful for models incorporating time
num_neighbors: int
number of neighboring timepoints to use
num_workers: int
number of workers to use for loading data. Must be specified here to schedule replacement workers for cached data
cache_rate: float
percentage of data to cache
replace_rate: float
percentage of data to replace
kwargs:
additional arguments to pass to DataLoader
"""
super().__init__()
self.csv_path = Path(csv_path)
(self.csv_path.parents[0] / "loaded_data").mkdir(exist_ok=True, parents=True)
self.df = pd.read_csv(csv_path)
self.img_data = {}
if isinstance(img_data, (str, Path)):
# read img_data if it's a path, otherwise set to empty dict
self.img_data["train"] = [
row._asdict()
for row in pd.read_csv(Path(img_data) / "train_img_data.csv").itertuples()
]
self.img_data["val"] = [
row._asdict()
for row in pd.read_csv(Path(img_data) / "val_img_data.csv").itertuples()
]
elif csv_path is not None:
self.csv_path = Path(csv_path)
(self.csv_path.parents[0] / "loaded_data").mkdir(exist_ok=True, parents=True)
self.df = pd.read_csv(csv_path)
else:
raise ValueError("csv_path or img_data must be specified")
self.num_workers = num_workers
self.kwargs = kwargs
val_size = np.min([n_val, int(len(self.df) * pct_val)])
self.val_size = np.max([val_size, 1])
# read img_data if it's a path, otherwise set to empty dict
if isinstance(img_data, str):
self.img_data = pd.read_csv(img_data)
else:
self.img_data = {}

self.n_val = n_val
self.pct_val = pct_val

self.datasets = {}
self.img_path_column = img_path_column
self.channel_column = channel_column
self.spatial_dims = spatial_dims
self.transforms = transforms
self.neighboring_timepoints = neighboring_timepoints
self.num_neighbors = num_neighbors
self.cache_rate = cache_rate
self.replace_rate = replace_rate

def _get_scenes(self, img):
"""Get the number of scenes in an image."""
@@ -86,8 +102,8 @@ def _get_scenes(self, img):
def _get_timepoints(self, img):
"""Get the number of timepoints in an image."""
timepoints = list(range(img.dims.T))
if self.neighboring_timepoints:
return timepoints[:-1]
if self.num_neighbors > 0:
return timepoints[: -self.num_neighbors]
return timepoints

@dask.delayed
@@ -97,18 +113,19 @@ def _get_file_args(self, row):
scenes = self._get_scenes(img)
timepoints = self._get_timepoints(img)
img_data = []
use_neighbors = self.num_neighbors > 0
for scene in scenes:
for timepoint in timepoints:
img_data.append(
{
"dimension_order_out": "ZYX"[-self.spatial_dims :]
if not self.neighboring_timepoints
if not use_neighbors
else "T" + "ZYX"[-self.spatial_dims :],
"C": row[self.channel_column],
"scene": scene,
"T": timepoint
if not self.neighboring_timepoints
else [timepoint, timepoint + 1],
if not use_neighbors
else [timepoint + i for i in range(self.num_neighbors + 1)],
"original_path": row[self.img_path_column],
}
)
@@ -127,11 +144,33 @@ def prepare_data(self):

def setup(self, stage=None):
if stage == "fit":
# split df into train/test/val
self.df_train, self.df_val = train_test_split(self.df, test_size=self.val_size)
if "train" in self.img_data and "val" in self.img_data:
self.datasets["train"] = SmartCacheDataset(
self.img_data["train"],
transform=self.transforms["train"],
cache_rate=self.cache_rate,
num_replace_workers=2,
num_init_workers=self.num_workers,
replace_rate=self.replace_rate,
)
self.datasets["val"] = CacheDataset(
self.img_data["val"],
transform=self.transforms["valid"],
cache_rate=0.02, # 1.0,
num_workers=self.num_workers,
)
return
# update img_data
self.img_data["train"] = self.get_per_file_args(self.df_train)
self.img_data["val"] = self.get_per_file_args(self.df_val)
image_data = self.get_per_file_args(self.df)
val_size = np.min([self.n_val, int(len(image_data) * self.pct_val)])
val_size = np.max([val_size, 1])
self.img_data["train"], self.img_data["val"] = train_test_split(
image_data, test_size=val_size
)

print("Train images:", len(self.img_data["train"]))
print("Val images:", len(self.img_data["val"]))

pd.DataFrame(self.img_data["train"]).to_csv(
f"{self.csv_path.parents[0]}/loaded_data/train_img_data.csv",
index=False,
@@ -144,13 +183,15 @@ def setup(self, stage=None):
self.img_data["train"],
transform=self.transforms["train"],
cache_rate=self.cache_rate,
num_replace_workers=self.num_workers // 2,
num_replace_workers=self.num_workers,
num_init_workers=self.num_workers,
replace_rate=self.replace_rate,
)
self.datasets["val"] = SmartCacheDataset(
self.datasets["val"] = CacheDataset(
self.img_data["val"],
transform=self.transforms["val"],
cache_rate=self.cache_rate,
num_replace_workers=self.num_workers // 2,
transform=self.transforms["valid"],
cache_rate=1.0,
num_workers=self.num_workers,
)

elif stage in ("test", "predict"):
@@ -160,7 +201,8 @@ def setup(self, stage=None):
def make_dataloader(self, split):
# smartcachedataset can't have persistent workers
self.kwargs["persistent_workers"] = split not in ("train", "val")
self.kwargs.pop("num_workers")
if "num_workers" in self.kwargs:
del self.kwargs["num_workers"]
return DataLoader(
self.datasets[split],
num_workers=self.num_workers,
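For reference, a minimal sketch (not part of this commit) of how the updated datamodule might be constructed. The manifest path, transform pipelines, and parameter values below are hypothetical; the "train"/"valid" transform keys and the loaded_data CSV layout are taken from the code above.

from monai.transforms import Compose, Lambdad
from cyto_dl.datamodules.smartcache import SmartcacheDatamodule

# The datamodule looks up pipelines under "train" and "valid" (not "val").
transforms = {
    "train": Compose([Lambdad(keys="raw", func=lambda x: x)]),
    "valid": Compose([Lambdad(keys="raw", func=lambda x: x)]),
}

# Fresh run: the manifest is split into train/val in setup(), which also
# writes the split to <csv_dir>/loaded_data/{train,val}_img_data.csv.
dm = SmartcacheDatamodule(
    csv_path="data/manifest.csv",  # hypothetical manifest
    transforms=transforms,
    num_neighbors=1,  # each item covers timepoints [T, T+1]
    cache_rate=0.5,
    replace_rate=0.1,
    num_workers=4,
)

# Later run: reuse the saved split instead of re-splitting the manifest.
dm_resumed = SmartcacheDatamodule(
    img_data="data/loaded_data",
    transforms=transforms,
    num_workers=4,
)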
5 changes: 4 additions & 1 deletion cyto_dl/image/io/aicsimage_loader.py
@@ -1,5 +1,6 @@
from typing import List

import numpy as np
from aicsimageio import AICSImage
from monai.data import MetaTensor
from monai.transforms import Transform
@@ -20,6 +21,7 @@ def __init__(
kwargs_keys: List = ["dimension_order_out", "C", "T"],
out_key: str = "raw",
allow_missing_keys=False,
dtype: np.dtype = np.float16,
):
"""
Parameters
@@ -41,6 +43,7 @@ def __init__(
self.allow_missing_keys = allow_missing_keys
self.out_key = out_key
self.scene_key = scene_key
self.dtype = dtype

def __call__(self, data):
# copying prevents the dataset from being modified in place - important when using partially cached datasets so that memory use doesn't increase over time
@@ -52,7 +55,7 @@ def __call__(self, data):
if self.scene_key in data:
img.set_scene(data[self.scene_key])
kwargs = {k: data[k] for k in self.kwargs_keys}
img = img.get_image_dask_data(**kwargs).compute()
img = img.get_image_dask_data(**kwargs).compute().astype(self.dtype)
data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs})

return data
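A short sketch of the new dtype argument in use. The class name AICSImageLoaderd is an assumption here (the class statement is outside the diff), and the keyword values are illustrative:

import numpy as np
from cyto_dl.image.io.aicsimage_loader import AICSImageLoaderd  # class name assumed

# Loading as float16 halves the memory held by cached images relative to
# float32; pass np.float32 if downstream transforms need the precision.
loader = AICSImageLoaderd(out_key="raw", dtype=np.float16)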
14 changes: 14 additions & 0 deletions cyto_dl/models/im2im/multi_task.py
@@ -1,4 +1,5 @@
import sys
from contextlib import suppress
from pathlib import Path
from typing import Dict, List, Union

@@ -220,3 +221,16 @@ def predict_step(self, batch, batch_idx):
save_image = self.should_save_image(batch_idx, stage)
self.run_forward(batch, stage, save_image, run_heads)
return io_map

# utils for smartcache training
def on_train_start(self):
with suppress(AttributeError):
self.trainer.datamodule.train_dataloader().dataset.start()

def on_train_epoch_end(self):
with suppress(AttributeError):
self.trainer.datamodule.train_dataloader().dataset.update_cache()

def on_train_end(self, *args, **kwargs):
with suppress(AttributeError):
self.trainer.datamodule.train_dataloader().dataset.shutdown()
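These hooks mirror MONAI's SmartCacheDataset lifecycle (start, update_cache, shutdown); the suppress(AttributeError) guard turns them into no-ops when the active dataset is a plain Dataset or CacheDataset. A standalone sketch of that lifecycle, with hypothetical data:

from monai.data import SmartCacheDataset

dataset = SmartCacheDataset(
    data=[{"idx": i} for i in range(100)],  # hypothetical items
    transform=None,
    cache_rate=0.5,  # keep half the items cached at once
    replace_rate=0.1,  # swap ~10% of the cache per update
)

dataset.start()  # launch background replacement workers
for epoch in range(3):
    for i in range(len(dataset)):  # only the cached subset is visible
        _ = dataset[i]
    dataset.update_cache()  # rotate new items into the cache
dataset.shutdown()  # stop the replacement workers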
