Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dry run for loaders to check for valid instances #705

Merged
merged 7 commits into from
Oct 8, 2024
Merged
93 changes: 57 additions & 36 deletions micro_sam/training/training.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import time
import warnings
from contextlib import contextmanager, nullcontext
from glob import glob
from tqdm import tqdm
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio
Expand Down Expand Up @@ -32,8 +33,8 @@
FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder):
x, y = next(iter(loader))
def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
x, _ = next(iter(loader))

# Raw data: check that we have 1 or 3 channels.
n_channels = x.shape[1]
Expand All @@ -57,8 +58,9 @@ def _check_loader(loader, with_segmentation_decoder):
)

# Target data: the check depends on whether we train with or without decoder.
# NOTE: Verification step to check whether all labels from dataloader are valid (i.e. have atleast one instance).

def check_instance_channel(instance_channel):
def _check_instance_channel(instance_channel):
unique_vals = torch.unique(instance_channel)
if (unique_vals < 0).any():
raise ValueError(
Expand All @@ -73,38 +75,53 @@ def check_instance_channel(instance_channel):
"All values in the target channel with the instance segmentation must be integer."
)

n_channels_y = y.shape[1]
if with_segmentation_decoder:
if n_channels_y != 4:
raise ValueError(
"Invalid number of channels in the target data from the data loader. "
"Expect 4 channel for training with an instance segmentation decoder, "
f"but got {n_channels_y} channels."
)
check_instance_channel(y[:, 0])
counter = 0
name = "" if name is None else f"'{name}'"
for x, y in tqdm(
loader,
desc=f"Verifying labels in {name} dataloader",
total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
):
n_channels_y = y.shape[1]
if with_segmentation_decoder:
if n_channels_y != 4:
raise ValueError(
"Invalid number of channels in the target data from the data loader. "
"Expect 4 channel for training with an instance segmentation decoder, "
f"but got {n_channels_y} channels."
)
# Check instance channel per sample in a batch
for per_y_sample in y:
_check_instance_channel(per_y_sample[0])

targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
if targets_min < 0 or targets_min > 1:
raise ValueError(
"Invalid value range in the target data from the value loader. "
"Expect the 3 last target channels (for normalized distances and foreground probabilities)"
f"to be in range [0.0, 1.0], but got min {targets_min}"
)
if targets_max < 0 or targets_max > 1:
raise ValueError(
"Invalid value range in the target data from the value loader. "
"Expect the 3 last target channels (for normalized distances and foreground probabilities)"
f"to be in range [0.0, 1.0], but got max {targets_max}"
)

targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
if targets_min < 0 or targets_min > 1:
raise ValueError(
"Invalid value range in the target data from the value loader. "
"Expect the 3 last target channels (for normalized distances and foreground probabilities)"
f"to be in range [0.0, 1.0], but got min {targets_min}"
)
if targets_max < 0 or targets_max > 1:
raise ValueError(
"Invalid value range in the target data from the value loader. "
"Expect the 3 last target channels (for normalized distances and foreground probabilities)"
f"to be in range [0.0, 1.0], but got max {targets_max}"
)
else:
if n_channels_y != 1:
raise ValueError(
"Invalid number of channels in the target data from the data loader. "
"Expect 1 channel for training without an instance segmentation decoder,"
f"but got {n_channels_y} channels."
)
# Check instance channel per sample in a batch
for per_y_sample in y:
_check_instance_channel(per_y_sample)

else:
if n_channels_y != 1:
raise ValueError(
"Invalid number of channels in the target data from the data loader. "
"Expect 1 channel for training without an instance segmentation decoder,"
f"but got {n_channels_y} channels."
)
check_instance_channel(y)
counter += 1
if verify_n_labels_in_loader is not None and counter > verify_n_labels_in_loader:
break


# Make the progress bar callbacks compatible with a tqdm progress bar interface.
Expand Down Expand Up @@ -170,6 +187,7 @@ def train_sam(
optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
peft_kwargs: Optional[Dict] = None,
ignore_warnings: bool = True,
verify_n_labels_in_loader: Optional[int] = 50,
**model_kwargs,
) -> None:
"""Run training for a SAM model.
Expand Down Expand Up @@ -208,14 +226,17 @@ def train_sam(
optimizer_class: The optimizer class.
By default, torch.optim.AdamW is used.
peft_kwargs: Keyword arguments for the PEFT wrapper class.
verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
By default, 50 batches of labels are verified from the dataloaders.
model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
ignore_warnings: Whether to ignore raised warnings.
"""
with _filter_warnings(ignore_warnings):

t_start = time.time()

_check_loader(train_loader, with_segmentation_decoder)
_check_loader(val_loader, with_segmentation_decoder)
_check_loader(train_loader, with_segmentation_decoder, verify_n_labels_in_loader)
_check_loader(val_loader, with_segmentation_decoder, verify_n_labels_in_loader)

device = get_device(device)
# Get the trainable segment anything model.
Expand Down
Loading