Skip to content

Commit

Permalink
document data methods (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu authored Oct 31, 2023
1 parent 2464a95 commit f889ff1
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion viscy/light/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@


def _ensure_channel_list(str_or_seq: Union[str, Sequence[str]]):
"""
Ensure channel argument is a list of strings.
:param Union[str, Sequence[str]] str_or_seq: channel name or list of channel names
:return list[str]: list of channel names
"""
if isinstance(str_or_seq, str):
return [str_or_seq]
try:
Expand All @@ -52,12 +58,16 @@ def _search_int_in_str(pattern: str, file_name: str) -> str:


class ChannelMap(TypedDict, total=False):
"""Source and target channel names."""

source: Union[str, Sequence[str]]
# optional
target: Union[str, Sequence[str]]


class Sample(TypedDict, total=False):
"""Image sample type for mini-batches."""

index: tuple[str, int, int]
# optional
source: Union[torch.Tensor, Sequence[torch.Tensor]]
Expand All @@ -66,6 +76,13 @@ class Sample(TypedDict, total=False):


def _collate_samples(batch: Sequence[Sample]) -> Sample:
"""Collate samples into a batch sample.
:param Sequence[Sample] batch: a sequence of dictionaries,
where each key may point to a value of a single tensor or a list of tensors,
as is the case with ``train_patches_per_stack > 1``.
:return Sample: Batch sample (dictionary of tensors)
"""
elemment = batch[0]
collated = {}
for key in elemment.keys():
Expand Down Expand Up @@ -190,6 +207,7 @@ def __len__(self) -> int:
def _stack_channels(
self, sample_images: list[dict[str, torch.Tensor]], key: str
) -> torch.Tensor:
"""Stack single-channel images into a multi-channel tensor."""
if not isinstance(sample_images, list):
return torch.stack([sample_images[ch][0] for ch in self.channels[key]])
# training time
Expand Down Expand Up @@ -410,6 +428,7 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
raise NotImplementedError(f"{stage} stage")

def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]:
"""Setup stages where the target is available (evaluating performance)."""
dataset_settings["channels"]["target"] = self.target_channel
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
plate = open_ome_zarr(data_path, mode="r")
Expand All @@ -426,6 +445,7 @@ def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]:
return plate, normalize_transform

def _setup_fit(self, dataset_settings: dict):
"""Set up the training and validation datasets."""
plate, normalize_transform = self._setup_eval(dataset_settings)
fit_transform = self._fit_transform()
train_transform = Compose(
Expand Down Expand Up @@ -456,7 +476,8 @@ def _setup_fit(self, dataset_settings: dict):
positions[num_train_fovs:], transform=val_transform, **dataset_settings
)

def _setup_test(self, dataset_settings):
def _setup_test(self, dataset_settings: dict):
"""Set up the test stage."""
if self.batch_size != 1:
logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
plate, normalize_transform = self._setup_eval(dataset_settings)
Expand All @@ -475,6 +496,7 @@ def _setup_test(self, dataset_settings):
)

def _setup_predict(self, dataset_settings: dict):
"""Set up the predict stage."""
# track metadata for inverting transform
set_track_meta(True)
if self.caching:
Expand All @@ -495,6 +517,7 @@ def _setup_predict(self, dataset_settings: dict):
)

def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
"""Removes redundant Z slices if the target is 2D to save VRAM."""
predicting = False
if self.trainer:
if self.trainer.predicting:
Expand Down Expand Up @@ -544,6 +567,7 @@ def predict_dataloader(self):
)

def _fit_transform(self):
"""Deterministic center crop as the last step."""
return [
CenterSpatialCropd(
keys=self.source_channel + self.target_channel,
Expand All @@ -556,6 +580,7 @@ def _fit_transform(self):
]

def _train_transform(self) -> list[Callable]:
"""Random crop sampling and augmentation for training."""
transforms = [
RandWeightedCropd(
keys=self.source_channel + self.target_channel,
Expand Down

0 comments on commit f889ff1

Please sign in to comment.