Skip to content

Commit

Permalink
Recalculate crop size if user-specified crop size indivisible by max …
Browse files Browse the repository at this point in the history
…stride (#841)
  • Loading branch information
roomrys authored Jul 19, 2022
1 parent fefd2b5 commit d522f1d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 6 deletions.
11 changes: 8 additions & 3 deletions sleap/gui/learning/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def show_win(


def make_datagen_results(reader: LabelsReader, cfg: TrainingJobConfig) -> np.ndarray:
"""
Gets (subset of) raw images used for training.
"""Get (subset of) raw images used for training.
TODO: Refactor so we can get this data without digging into details of the
the specific pipelines (e.g., key for confmaps depends on head type).
Expand Down Expand Up @@ -113,11 +112,17 @@ def make_datagen_results(reader: LabelsReader, cfg: TrainingJobConfig) -> np.nda
output_keys["confmap"] = "centroid_confidence_maps"

elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig):
if cfg.data.instance_cropping.crop_size is None:
if (cfg.data.instance_cropping.crop_size is None) or (
cfg.data.instance_cropping.crop_size
% cfg.model.backbone.which_oneof().max_stride
> 0
):
# Compute crop size that is divisible by max stride
cfg.data.instance_cropping.crop_size = find_instance_crop_size(
labels=reader.labels,
padding=cfg.data.instance_cropping.crop_size_detection_padding,
maximum_stride=cfg.model.backbone.which_oneof().max_stride,
min_crop_size=cfg.data.instance_cropping.crop_size,
)

pipeline += pipelines.InstanceCentroidFinder.from_config(
Expand Down
13 changes: 12 additions & 1 deletion sleap/nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def find_instance_crop_size(
padding: int = 0,
maximum_stride: int = 2,
input_scaling: float = 1.0,
min_crop_size: Optional[int] = None,
) -> int:
"""Compute the size of the largest instance bounding box from labels.
Expand All @@ -24,19 +25,29 @@ def find_instance_crop_size(
architecture.
input_scaling: Float factor indicating the scale of the input images if any
scaling will be done before cropping.
min_crop_size: The (optional) crop size set by the user. None if not set.
Returns:
An integer crop size denoting the length of the side of the bounding boxes that
will contain the instances when cropped.
will contain the instances when cropped. The returned crop size will be larger
or equal to the input `crop_size`.
This accounts for stride, padding and scaling when ensuring divisibility.
"""
# Check if user-specified crop size is divisible by max stride
min_crop_size = 0 if min_crop_size is None else min_crop_size
if (min_crop_size > 0) and (min_crop_size % maximum_stride == 0):
return min_crop_size

# Calculate crop size
min_crop_size_no_pad = min_crop_size - padding
max_length = 0.0
for inst in labels.user_instances:
pts = inst.points_array
pts *= input_scaling
max_length = np.maximum(max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0]))
max_length = np.maximum(max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1]))
max_length = np.maximum(max_length, min_crop_size_no_pad)

max_length += float(padding)
crop_size = np.math.ceil(max_length / float(maximum_stride)) * maximum_stride
Expand Down
12 changes: 10 additions & 2 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,12 +1224,16 @@ def _update_config(self):
if self.config.data.preprocessing.pad_to_stride is None:
self.config.data.preprocessing.pad_to_stride = 1

if self.config.data.instance_cropping.crop_size is None:
if (self.config.data.instance_cropping.crop_size is None) or (
self.config.data.instance_cropping.crop_size % self.model.maximum_stride > 0
):
# Compute crop size that is divisible by max stride
self.config.data.instance_cropping.crop_size = sleap.nn.data.instance_cropping.find_instance_crop_size(
self.data_readers.training_labels,
padding=self.config.data.instance_cropping.crop_size_detection_padding,
maximum_stride=self.model.maximum_stride,
input_scaling=self.config.data.preprocessing.input_scaling,
min_crop_size=self.config.data.instance_cropping.crop_size,
)

if self.config.optimization.batches_per_epoch is None:
Expand Down Expand Up @@ -1631,12 +1635,16 @@ def _update_config(self):
if self.config.data.preprocessing.pad_to_stride is None:
self.config.data.preprocessing.pad_to_stride = self.model.maximum_stride

if self.config.data.instance_cropping.crop_size is None:
if (self.config.data.instance_cropping.crop_size is None) or (
self.config.data.instance_cropping.crop_size % self.model.maximum_stride > 0
):
# Compute crop size which is divisible by max stride
self.config.data.instance_cropping.crop_size = sleap.nn.data.instance_cropping.find_instance_crop_size(
self.data_readers.training_labels,
padding=self.config.data.instance_cropping.crop_size_detection_padding,
maximum_stride=self.model.maximum_stride,
input_scaling=self.config.data.preprocessing.input_scaling,
min_crop_size=self.config.data.instance_cropping.crop_size,
)

if self.config.optimization.batches_per_epoch is None:
Expand Down
28 changes: 28 additions & 0 deletions tests/nn/test_training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import sleap
from sleap.io.dataset import Labels
from sleap.nn.config.data import LabelsConfig
from sleap.nn.config.model import (
CenteredInstanceConfmapsHeadConfig,
Expand All @@ -16,6 +17,8 @@
DataReaders,
SingleInstanceModelTrainer,
TopdownConfmapsModelTrainer,
TopDownMultiClassModelTrainer,
Trainer,
)

sleap.use_cpu_only()
Expand Down Expand Up @@ -227,3 +230,28 @@ def test_train_topdown_multiclass(min_tracks_2node_labels, cfg):
assert trainer.keras_model.output_names[1] == "ClassVectorsHead"
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 64, 64, 2)
assert tuple(trainer.keras_model.outputs[1].shape) == (None, 2)


@pytest.mark.parametrize(
"trainer_class", [TopdownConfmapsModelTrainer, TopDownMultiClassModelTrainer]
)
def test_train_cropping(
training_labels: Labels, cfg: TrainingJobConfig, trainer_class: Trainer
):
# Set model head
cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig(
sigma=1.5, output_stride=1, offset_refinement=False
)

# Create trainer
trainer = trainer_class.from_config(cfg, training_labels=training_labels)

# Change trainer.config s.t. crop size not divisible by max stride
trainer.config.data.instance_cropping.crop_size = trainer.model.maximum_stride + 1

# Ensure crop size is updated to be divisible by max stride
trainer._update_config()
assert (
trainer.config.data.instance_cropping.crop_size % trainer.model.maximum_stride
== 0
)

0 comments on commit d522f1d

Please sign in to comment.