From f889ff152d1de9ae14e3abcef746d259fd6b2cea Mon Sep 17 00:00:00 2001
From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com>
Date: Tue, 31 Oct 2023 12:54:10 -0700
Subject: [PATCH] document data methods (#50)

---
 viscy/light/data.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/viscy/light/data.py b/viscy/light/data.py
index 41eb62f6..1af34242 100644
--- a/viscy/light/data.py
+++ b/viscy/light/data.py
@@ -30,6 +30,12 @@
 
 
 def _ensure_channel_list(str_or_seq: Union[str, Sequence[str]]):
+    """
+    Ensure channel argument is a list of strings.
+
+    :param Union[str, Sequence[str]] str_or_seq: channel name or list of channel names
+    :return list[str]: list of channel names
+    """
     if isinstance(str_or_seq, str):
         return [str_or_seq]
     try:
@@ -52,12 +58,16 @@ def _search_int_in_str(pattern: str, file_name: str) -> str:
 
 
 class ChannelMap(TypedDict, total=False):
+    """Source and target channel names."""
+
     source: Union[str, Sequence[str]]
     # optional
     target: Union[str, Sequence[str]]
 
 
 class Sample(TypedDict, total=False):
+    """Image sample type for mini-batches."""
+
     index: tuple[str, int, int]
     # optional
     source: Union[torch.Tensor, Sequence[torch.Tensor]]
@@ -66,6 +76,13 @@ class Sample(TypedDict, total=False):
 
 
 def _collate_samples(batch: Sequence[Sample]) -> Sample:
+    """Collate samples into a batch sample.
+
+    :param Sequence[Sample] batch: a sequence of dictionaries,
+        where each key may point to a value of a single tensor or a list of tensors,
+        as is the case with ``train_patches_per_stack > 1``.
+    :return Sample: Batch sample (dictionary of tensors)
+    """
     elemment = batch[0]
     collated = {}
     for key in elemment.keys():
@@ -190,6 +207,7 @@ def __len__(self) -> int:
     def _stack_channels(
         self, sample_images: list[dict[str, torch.Tensor]], key: str
     ) -> torch.Tensor:
+        """Stack single-channel images into a multi-channel tensor."""
         if not isinstance(sample_images, list):
             return torch.stack([sample_images[ch][0] for ch in self.channels[key]])
         # training time
@@ -410,6 +428,7 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
             raise NotImplementedError(f"{stage} stage")
 
     def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]:
+        """Setup stages where the target is available (evaluating performance)."""
         dataset_settings["channels"]["target"] = self.target_channel
         data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
         plate = open_ome_zarr(data_path, mode="r")
@@ -426,6 +445,7 @@ def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]:
         return plate, normalize_transform
 
     def _setup_fit(self, dataset_settings: dict):
+        """Set up the training and validation datasets."""
         plate, normalize_transform = self._setup_eval(dataset_settings)
         fit_transform = self._fit_transform()
         train_transform = Compose(
@@ -456,7 +476,8 @@ def _setup_fit(self, dataset_settings: dict):
             positions[num_train_fovs:], transform=val_transform, **dataset_settings
         )
 
-    def _setup_test(self, dataset_settings):
+    def _setup_test(self, dataset_settings: dict):
+        """Set up the test stage."""
         if self.batch_size != 1:
             logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
         plate, normalize_transform = self._setup_eval(dataset_settings)
@@ -475,6 +496,7 @@ def _setup_test(self, dataset_settings):
         )
 
     def _setup_predict(self, dataset_settings: dict):
+        """Set up the predict stage."""
         # track metadata for inverting transform
         set_track_meta(True)
         if self.caching:
@@ -495,6 +517,7 @@ def _setup_predict(self, dataset_settings: dict):
         )
 
     def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
+        """Removes redundant Z slices if the target is 2D to save VRAM."""
         predicting = False
         if self.trainer:
             if self.trainer.predicting:
@@ -544,6 +567,7 @@ def predict_dataloader(self):
         )
 
     def _fit_transform(self):
+        """Deterministic center crop as the last step."""
         return [
             CenterSpatialCropd(
                 keys=self.source_channel + self.target_channel,
@@ -556,6 +580,7 @@ def _fit_transform(self):
         ]
 
     def _train_transform(self) -> list[Callable]:
+        """Random crop sampling and augmentation for training."""
         transforms = [
             RandWeightedCropd(
                 keys=self.source_channel + self.target_channel,
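
A note on the collation behaviour documented above: the ``_collate_samples`` docstring says each key of a sample may hold either a single tensor or a list of patch tensors (the ``train_patches_per_stack > 1`` case). The snippet below is an illustrative sketch of that flatten-and-stack idea in plain PyTorch; the name ``collate_patch_samples`` and the tensor shapes are hypothetical and this is not the implementation in viscy/light/data.py.

import torch

def collate_patch_samples(batch):
    # Hypothetical helper (not the viscy API): flatten each sample's list of
    # patch tensors and stack them into one batch tensor per key, mimicking
    # what the _collate_samples docstring describes for
    # train_patches_per_stack > 1.
    collated = {}
    for key in batch[0]:
        patches = [patch for sample in batch for patch in sample[key]]
        collated[key] = torch.stack(patches)
    return collated

# Two samples, each holding 2 random (C, Z, Y, X) patches per key.
batch = [
    {
        "source": [torch.rand(1, 5, 64, 64) for _ in range(2)],
        "target": [torch.rand(1, 5, 64, 64) for _ in range(2)],
    }
    for _ in range(2)
]
print(collate_patch_samples(batch)["source"].shape)  # torch.Size([4, 1, 5, 64, 64])

A collate function of this shape would typically be passed as ``collate_fn`` to a PyTorch ``DataLoader`` for the training split.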