diff --git a/sleap/nn/config/data.py b/sleap/nn/config/data.py
index 2c14cc645..bc091a547 100644
--- a/sleap/nn/config/data.py
+++ b/sleap/nn/config/data.py
@@ -76,6 +76,13 @@ class PreprocessingConfig:
            max stride (typically 32). This padding will be ignored when instance
            cropping inputs since the crop size should already be divisible by the
            model's max stride.
+        resize_and_pad_to_target: If True, all images in the dataset will be resized and padded to match the target dimensions.
+            This is useful when preprocessing datasets with mixed image dimensions (from different video resolutions).
+            Aspect ratio is preserved, and padding is applied (if needed) to the bottom or right of the image only.
+        target_height: Target image height for `resize_and_pad_to_target`. When not explicitly provided, inferred as the
+            max image height from the dataset.
+        target_width: Target image width for `resize_and_pad_to_target`. When not explicitly provided, inferred as the
+            max image width from the dataset.
     """

     ensure_rgb: bool = False
@@ -88,6 +95,9 @@ class PreprocessingConfig:
     )
     input_scaling: float = 1.0
     pad_to_stride: Optional[int] = None
+    resize_and_pad_to_target: bool = True
+    target_height: Optional[int] = None
+    target_width: Optional[int] = None


 @attr.s(auto_attribs=True)
diff --git a/sleap/nn/data/pipelines.py b/sleap/nn/data/pipelines.py
index eb7b51c2f..d8ee2a5a5 100644
--- a/sleap/nn/data/pipelines.py
+++ b/sleap/nn/data/pipelines.py
@@ -22,7 +22,7 @@
     RandomCropper,
 )
 from sleap.nn.data.normalization import Normalizer
-from sleap.nn.data.resizing import Resizer, PointsRescaler
+from sleap.nn.data.resizing import Resizer, PointsRescaler, SizeMatcher
 from sleap.nn.data.instance_centroids import InstanceCentroidFinder
 from sleap.nn.data.instance_cropping import InstanceCropper, PredictedInstanceCropper
 from sleap.nn.data.confidence_maps import (
@@ -68,6 +68,7 @@
     RandomCropper,
     Normalizer,
     Resizer,
+    SizeMatcher,
     InstanceCentroidFinder,
     InstanceCropper,
     MultiConfidenceMapGenerator,
@@ -123,7 +124,7 @@ def from_blocks(
         """Create a pipeline from a sequence of providers and transformers.

         Args:
-            sequence: List or tuple of providers and transformer instances.
+            blocks: List or tuple of providers and transformer instances.

         Returns:
             An instantiated pipeline with all blocks chained.
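Editor's note: a minimal usage sketch of the new preprocessing options (not part of the patch). `PreprocessingConfig` and the field names come from the diff above; the target values are hypothetical.

```python
from sleap.nn.config.data import PreprocessingConfig

# Let the pipeline infer the target size from the largest video in the dataset:
auto_cfg = PreprocessingConfig(resize_and_pad_to_target=True)

# Or pin the target dimensions explicitly (values here are hypothetical):
fixed_cfg = PreprocessingConfig(
    resize_and_pad_to_target=True,
    target_height=1024,
    target_width=1280,
)
```

Leaving `target_height` / `target_width` unset defers the choice to the data provider via `SizeMatcher.from_config`, as wired into the pipelines below.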
@@ -351,6 +352,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
         """
         pipeline = Pipeline(providers=data_provider)
         pipeline += Normalizer.from_config(self.data_config.preprocessing)
+        if self.data_config.preprocessing.resize_and_pad_to_target:
+            pipeline += SizeMatcher.from_config(
+                config=self.data_config.preprocessing,
+                provider=data_provider,
+            )
         pipeline += Resizer.from_config(self.data_config.preprocessing)
         if self.optimization_config.augmentation_config.random_crop:
             pipeline += RandomCropper(
@@ -391,8 +397,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
                 crop_width=self.optimization_config.augmentation_config.random_crop_width,
             )
         pipeline += Normalizer.from_config(self.data_config.preprocessing)
+        if self.data_config.preprocessing.resize_and_pad_to_target:
+            pipeline += SizeMatcher.from_config(
+                config=self.data_config.preprocessing,
+                provider=data_provider,
+            )
         pipeline += Resizer.from_config(self.data_config.preprocessing)
-
         pipeline += SingleInstanceConfidenceMapGenerator(
             sigma=self.single_instance_confmap_head.sigma,
             output_stride=self.single_instance_confmap_head.output_stride,
@@ -483,6 +493,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
         """
         pipeline = Pipeline(providers=data_provider)
         pipeline += Normalizer.from_config(self.data_config.preprocessing)
+        if self.data_config.preprocessing.resize_and_pad_to_target:
+            pipeline += SizeMatcher.from_config(
+                config=self.data_config.preprocessing,
+                provider=data_provider,
+            )
         pipeline += Resizer.from_config(self.data_config.preprocessing)
         if self.optimization_config.augmentation_config.random_crop:
             pipeline += RandomCropper(
@@ -529,8 +544,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
                 crop_width=self.optimization_config.augmentation_config.random_crop_width,
             )
         pipeline += Normalizer.from_config(self.data_config.preprocessing)
+        if self.data_config.preprocessing.resize_and_pad_to_target:
+            pipeline += SizeMatcher.from_config(
+                config=self.data_config.preprocessing,
+                provider=data_provider,
+            )
         pipeline += Resizer.from_config(self.data_config.preprocessing)
-
         pipeline += InstanceCentroidFinder.from_config(
             self.data_config.instance_cropping,
             skeletons=self.data_config.labels.skeletons,
@@ -637,6 +656,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
         """
         pipeline = Pipeline(providers=data_provider)
         pipeline += Normalizer.from_config(self.data_config.preprocessing)
+        if self.data_config.preprocessing.resize_and_pad_to_target:
+            pipeline += SizeMatcher.from_config(
+                config=self.data_config.preprocessing,
+                provider=data_provider,
+            )
         pipeline += Resizer.from_config(self.data_config.preprocessing)
         pipeline += InstanceCentroidFinder.from_config(
             self.data_config.instance_cropping,
@@ -674,8 +698,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
             self.optimization_config.augmentation_config
         )
         pipeline += Normalizer.from_config(self.data_config.preprocessing)
+        if self.data_config.preprocessing.resize_and_pad_to_target:
+            pipeline += SizeMatcher.from_config(
+                config=self.data_config.preprocessing,
+                provider=data_provider,
+            )
         pipeline += Resizer.from_config(self.data_config.preprocessing)
-
         pipeline += InstanceCentroidFinder.from_config(
             self.data_config.instance_cropping,
             skeletons=self.data_config.labels.skeletons,
@@ -768,6 +796,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
         """
         pipeline = Pipeline(providers=data_provider)
         pipeline += Normalizer.from_config(self.data_config.preprocessing)
+        if self.data_config.preprocessing.resize_and_pad_to_target:
+            pipeline += SizeMatcher.from_config(
+                config=self.data_config.preprocessing,
+                provider=data_provider,
+            )
         pipeline += Resizer.from_config(self.data_config.preprocessing)
         if self.optimization_config.augmentation_config.random_crop:
             pipeline += RandomCropper(
@@ -809,8 +842,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
                 crop_width=aug_config.random_crop_width,
             )
         pipeline += Normalizer.from_config(self.data_config.preprocessing)
+        if self.data_config.preprocessing.resize_and_pad_to_target:
+            pipeline += SizeMatcher.from_config(
+                config=self.data_config.preprocessing,
+                provider=data_provider,
+            )
         pipeline += Resizer.from_config(self.data_config.preprocessing)
-
         pipeline += MultiConfidenceMapGenerator(
             sigma=self.confmaps_head.sigma,
             output_stride=self.confmaps_head.output_stride,
diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py
index 8998ee7ee..d2e04f2cb 100644
--- a/sleap/nn/data/providers.py
+++ b/sleap/nn/data/providers.py
@@ -3,7 +3,7 @@
 import numpy as np
 import tensorflow as tf
 import attr
-from typing import Text, Optional, List, Sequence, Union
+from typing import Text, Optional, List, Sequence, Union, Tuple

 import sleap
@@ -93,6 +93,21 @@ def videos(self) -> List[sleap.Video]:
         """Return the list of videos that `video_ind` in examples match up with."""
         return self.labels.videos

+    @property
+    def max_height_and_width(self) -> Tuple[int, int]:
+        """Return the maximum height and width across all videos."""
+        return max(video.shape[1] for video in self.videos), max(
+            video.shape[2] for video in self.videos
+        )
+
+    @property
+    def is_from_multi_size_videos(self) -> bool:
+        """Return True if the videos in the labels have different sizes."""
+        return (
+            len(set(v.shape[1] for v in self.videos)) > 1
+            or len(set(v.shape[2] for v in self.videos)) > 1
+        )
+
     def make_dataset(
         self, ds_index: Optional[tf.data.Dataset] = None
     ) -> tf.data.Dataset:
@@ -125,10 +138,10 @@ def make_dataset(
                 "skeleton_inds": Tensor of shape (n_instances,) of dtype tf.int32 that
                     specifies the index of the skeleton used for each instance.
         """
-        # Grab an image to test for the dtype.
-        test_lf = self.labels[0]
-        test_image = tf.convert_to_tensor(test_lf.image)
-        image_dtype = test_image.dtype
+        # Grab the first image to capture dtype and number of color channels.
+        first_image = tf.convert_to_tensor(self.labels[0].image)
+        image_dtype = first_image.dtype
+        image_num_channels = first_image.shape[-1]

         def py_fetch_lf(ind):
             """Local function that will not be autographed."""
@@ -167,7 +180,13 @@ def fetch_lf(ind):
             [ind],
             [image_dtype, tf.int32, tf.float32, tf.int32, tf.int64, tf.int32],
         )
-        image = tf.ensure_shape(image, test_image.shape)
+
+        # Use variable height/width if the videos have mixed sizes; otherwise fix the shape.
+        if self.is_from_multi_size_videos:
+            image = tf.ensure_shape(image, (None, None, image_num_channels))
+        else:
+            image = tf.ensure_shape(image, first_image.shape)
+
         instances = tf.ensure_shape(instances, tf.TensorShape([None, None, 2]))
         skeleton_inds = tf.ensure_shape(skeleton_inds, tf.TensorShape([None]))
diff --git a/sleap/nn/data/resizing.py b/sleap/nn/data/resizing.py
index 95ff5cfc4..56a2c0315 100644
--- a/sleap/nn/data/resizing.py
+++ b/sleap/nn/data/resizing.py
@@ -251,6 +251,182 @@ def resize(example):
         return ds_output


+@attr.s(auto_attribs=True)
+class SizeMatcher:
+    """Data transformer that ensures output images have a uniform shape by resizing and padding smaller images.
+
+    Attributes:
+        image_key: String name of the key containing the images to resize.
+        scale_key: String name of the key containing the scale of the images.
+        points_key: String name of the key containing points to adjust for the resizing
+            operation.
+        keep_full_image: If True, keeps the (original size) full image in the examples.
+            This is useful for multi-scale inference.
+        full_image_key: String name of the key containing the full images.
+        max_image_height: The target height to which all smaller images will be resized and padded.
+        max_image_width: The target width to which all smaller images will be resized and padded.
+    """
+
+    image_key: Text = "image"
+    scale_key: Text = "scale"
+    points_key: Optional[Text] = "instances"
+    keep_full_image: bool = False
+    full_image_key: Text = "full_image"
+    max_image_height: Optional[int] = None
+    max_image_width: Optional[int] = None
+
+    @classmethod
+    def from_config(
+        cls,
+        config: PreprocessingConfig,
+        provider: Optional["Provider"] = None,
+        update_config: bool = True,
+        image_key: Text = "image",
+        scale_key: Text = "scale",
+        keep_full_image: bool = False,
+        full_image_key: Text = "full_image",
+        points_key: Optional[Text] = "instances",
+    ) -> "SizeMatcher":
+        """Build an instance of this class from configuration.
+
+        Args:
+            config: A `PreprocessingConfig` instance with the desired parameters. If
+                `config.resize_and_pad_to_target` is True and `target_height` / `target_width`
+                are not set, a `provider` that implements `max_height_and_width` must be given.
+            provider: Data provider, e.g., a `LabelsReader` instance, from which the
+                target size is inferred when it is not set in the config.
+            update_config: If True, the input configuration will be updated with
+                values inferred from other fields.
+            image_key: String name of the key containing the images to resize.
+            scale_key: String name of the key containing the scale of the images.
+            keep_full_image: If True, keeps the (original size) full image in the
+                examples. This is useful for multi-scale inference.
+            full_image_key: String name of the key containing the full images.
+            points_key: String name of the key containing points to adjust for the
+                resizing operation.
+
+        Returns:
+            An instance of this class.
+
+        Raises:
+            ValueError: If `provider` is not set or does not implement `max_height_and_width`.
+ """ + if config.resize_and_pad_to_target: + if config.target_height is not None and config.target_width is not None: + max_height = config.target_height + max_width = config.target_width + else: + try: + max_height, max_width = provider.max_height_and_width + except: + raise ValueError( + "target_height / target_width could not be determined" + ) + if update_config: + config.target_height = max_height + config.target_width = max_width + else: + max_height, max_width = None, None + + return cls( + image_key=image_key, + points_key=points_key, + scale_key=scale_key, + keep_full_image=keep_full_image, + full_image_key=full_image_key, + max_image_height=max_height, + max_image_width=max_width, + ) + + @property + def input_keys(self) -> List[Text]: + """Return the keys that incoming elements are expected to have.""" + input_keys = [self.image_key, self.scale_key] + if self.points_key is not None: + input_keys.append(self.points_key) + return input_keys + + @property + def output_keys(self) -> List[Text]: + """Return the keys that outgoing elements will have.""" + output_keys = self.input_keys + if self.keep_full_image: + output_keys.append(self.full_image_key) + return output_keys + + def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset: + """Transform a dataset with potentially different size images into one with equal sized images. + + Args: + ds_input: A dataset with the image specified in the `image_key` attribute, + points specified in the `points_key` attribute, and the "scale" key for + tracking scaling transformations. + + Returns: + A `tf.data.Dataset` with elements containing the same images and points of equal size. + + If the `keep_full_image` attribute is True, a key specified by + `full_image_key` will be added with the to the example containing the image + before any processing. 
+ """ + + # mapping function: match to max height width by resizing and padding bottom/right accordingly + def resize_and_pad(example): + image = example[self.image_key] + if self.keep_full_image: + example[self.full_image_key] = image + + current_shape = tf.shape(image) + + # Only apply this transform if image shape differs from target + if ( + current_shape[-3] != self.max_image_height + or current_shape[-2] != self.max_image_width + ): + # Calculate target height and width for resizing the image (no padding yet) + hratio = self.max_image_height / tf.cast(current_shape[-3], tf.float32) + wratio = self.max_image_width / tf.cast(current_shape[-2], tf.float32) + if hratio > wratio: + # The bottleneck is width, scale to fit width first then pad to height + target_height = tf.cast( + tf.cast(current_shape[-3], tf.float32) * wratio, tf.int32 + ) + target_width = self.max_image_width + example[self.scale_key] = example[self.scale_key] * wratio + else: + # The bottleneck is height, scale to fit height first then pad to width + target_height = self.max_image_height + target_width = tf.cast( + tf.cast(current_shape[-2], tf.float32) * hratio, tf.int32 + ) + example[self.scale_key] = example[self.scale_key] * hratio + # Resize the image to fill one of the dimensions by preserving aspect ratio + image = tf.image.resize_with_pad( + image, target_height=target_height, target_width=target_width + ) + # Pad the image on bottom/right with zeroes to match specified dimensions + image = tf.image.pad_to_bounding_box( + image, + offset_height=0, + offset_width=0, + target_height=self.max_image_height, + target_width=self.max_image_width, + ) + example[self.image_key] = tf.cast(image, example[self.image_key].dtype) + # Scale the instance points accordingly + if self.points_key: + example[self.points_key] = ( + example[self.points_key] * example[self.scale_key] + ) + + return example + + ds_output = ds_input.map( + resize_and_pad, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) + return ds_output + + @attr.s(auto_attribs=True) class PointsRescaler: """Transformer to apply or invert scaling operations on points.""" diff --git a/tests/nn/data/test_providers.py b/tests/nn/data/test_providers.py index 114259b77..7b23c5b39 100644 --- a/tests/nn/data/test_providers.py +++ b/tests/nn/data/test_providers.py @@ -11,6 +11,8 @@ def test_labels_reader(min_labels): labels_reader = providers.LabelsReader.from_user_instances(min_labels) ds = labels_reader.make_dataset() + assert not labels_reader.is_from_multi_size_videos + example = next(iter(ds)) assert len(labels_reader) == 1 @@ -47,6 +49,8 @@ def test_labels_reader_no_visible_points(min_labels): labels_reader = providers.LabelsReader.from_user_instances(min_labels) ds = labels_reader.make_dataset() + assert not labels_reader.is_from_multi_size_videos + example = next(iter(ds)) # There should be two instances in the labels dataset @@ -134,3 +138,49 @@ def test_video_reader_hdf5(): assert example["raw_image_size"].dtype == tf.int32 np.testing.assert_array_equal(example["raw_image_size"], (512, 512, 1)) + + +def test_labels_reader_multi_size(): + # Create some fake data using two different size videos. 
+    skeleton = sleap.Skeleton.from_names_and_edge_inds(["A"])
+    labels = sleap.Labels(
+        [
+            sleap.LabeledFrame(
+                frame_idx=0,
+                video=sleap.Video.from_filename(
+                    TEST_SMALL_ROBOT_MP4_FILE, grayscale=True
+                ),
+                instances=[
+                    sleap.Instance.from_pointsarray(
+                        np.array([[128, 128]]), skeleton=skeleton
+                    )
+                ],
+            ),
+            sleap.LabeledFrame(
+                frame_idx=0,
+                video=sleap.Video.from_filename(
+                    TEST_H5_FILE, dataset="/box", input_format="channels_first"
+                ),
+                instances=[
+                    sleap.Instance.from_pointsarray(
+                        np.array([[128, 128]]), skeleton=skeleton
+                    )
+                ],
+            ),
+        ]
+    )
+
+    # Create a loader for those labels.
+    labels_reader = providers.LabelsReader(labels)
+    ds = labels_reader.make_dataset()
+    ds_iter = iter(ds)
+
+    # Check that LabelsReader can provide individual samples with different shapes.
+    assert next(ds_iter)["image"].shape == (320, 560, 1)
+    assert next(ds_iter)["image"].shape == (512, 512, 1)
+
+    # Check the utility properties.
+    h, w = labels_reader.max_height_and_width
+    assert h == 512
+    assert w == 560
+    assert labels_reader.is_from_multi_size_videos
diff --git a/tests/nn/data/test_resizing.py b/tests/nn/data/test_resizing.py
index 6510dc0a6..891bbb189 100644
--- a/tests/nn/data/test_resizing.py
+++ b/tests/nn/data/test_resizing.py
@@ -5,8 +5,15 @@
 use_cpu_only()  # hide GPUs for test
+import sleap
+from sleap.nn.system import use_cpu_only
+
+use_cpu_only()  # hide GPUs for test
 from sleap.nn.data import resizing
 from sleap.nn.data import providers
+from sleap.nn.data.resizing import SizeMatcher
+
+from tests.fixtures.videos import TEST_H5_FILE, TEST_SMALL_ROBOT_MP4_FILE


 def test_find_padding_for_stride():
@@ -117,3 +124,88 @@ def test_resizer_from_config():
     resizer = resizing.Resizer.from_config(
         config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=None)
     )
+
+
+def test_size_matcher():
+    # Create some fake data using two different size videos.
+    skeleton = sleap.Skeleton.from_names_and_edge_inds(["A"])
+    labels = sleap.Labels(
+        [
+            sleap.LabeledFrame(
+                frame_idx=0,
+                video=sleap.Video.from_filename(
+                    TEST_SMALL_ROBOT_MP4_FILE, grayscale=True
+                ),
+                instances=[
+                    sleap.Instance.from_pointsarray(
+                        np.array([[128, 128]]), skeleton=skeleton
+                    )
+                ],
+            ),
+            sleap.LabeledFrame(
+                frame_idx=0,
+                video=sleap.Video.from_filename(
+                    TEST_H5_FILE, dataset="/box", input_format="channels_first"
+                ),
+                instances=[
+                    sleap.Instance.from_pointsarray(
+                        np.array([[128, 128]]), skeleton=skeleton
+                    )
+                ],
+            ),
+        ]
+    )
+
+    # Create a loader for those labels.
+    labels_reader = providers.LabelsReader(labels)
+    ds = labels_reader.make_dataset()
+    ds_iter = iter(ds)
+    assert next(ds_iter)["image"].shape == (320, 560, 1)
+    assert next(ds_iter)["image"].shape == (512, 512, 1)
+
+    def check_padding(image, from_y, to_y, from_x, to_x):
+        for y in range(from_y, to_y):
+            for x in range(from_x, to_x):
+                assert image[y][x] == 0
+
+    # Check SizeMatcher when the target dims are not strictly larger than the actual image dims.
+    size_matcher = SizeMatcher(max_image_height=560, max_image_width=560)
+    transform_iter = iter(size_matcher.transform_dataset(ds))
+    im1 = next(transform_iter)["image"]
+    assert im1.shape == (560, 560, 1)
+    # Padding should be on the bottom.
+    check_padding(im1, 320, 560, 0, 560)
+    im2 = next(transform_iter)["image"]
+    assert im2.shape == (560, 560, 1)
+
+    # Check SizeMatcher when the target matches the smaller height and the larger width.
+    size_matcher = SizeMatcher(max_image_height=320, max_image_width=560)
+    transform_iter = iter(size_matcher.transform_dataset(ds))
+    im1 = next(transform_iter)["image"]
+    assert im1.shape == (320, 560, 1)
+    im2 = next(transform_iter)["image"]
+    assert im2.shape == (320, 560, 1)
+    # Padding should be on the right.
+    check_padding(im2, 0, 320, 320, 560)
+
+    # Check SizeMatcher when the target is the max of both dimensions.
+    size_matcher = SizeMatcher(max_image_height=512, max_image_width=560)
+    transform_iter = iter(size_matcher.transform_dataset(ds))
+    im1 = next(transform_iter)["image"]
+    assert im1.shape == (512, 560, 1)
+    # Check that padding is on the bottom.
+    check_padding(im1, 320, 512, 0, 560)
+    im2 = next(transform_iter)["image"]
+    assert im2.shape == (512, 560, 1)
+    # Check that padding is on the right.
+    check_padding(im2, 0, 512, 512, 560)
+
+    # Check SizeMatcher when the target is larger in both dimensions.
+    size_matcher = SizeMatcher(max_image_height=750, max_image_width=750)
+    transform_iter = iter(size_matcher.transform_dataset(ds))
+    im1 = next(transform_iter)["image"]
+    assert im1.shape == (750, 750, 1)
+    # Check that padding is on the bottom.
+    check_padding(im1, 700, 750, 0, 750)
+    im2 = next(transform_iter)["image"]
+    assert im2.shape == (750, 750, 1)
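Editor's note: a minimal sketch (not part of the patch) of driving `SizeMatcher` by hand, mirroring what `make_base_pipeline` now does when `resize_and_pad_to_target` is enabled. It assumes `labels` is a `sleap.Labels` containing videos of mixed resolutions, as constructed in the tests above.

```python
from sleap.nn.data import providers
from sleap.nn.data.resizing import SizeMatcher

labels_reader = providers.LabelsReader(labels)

# Infer the target size from the provider, as SizeMatcher.from_config does when
# target_height / target_width are left unset in the config.
height, width = labels_reader.max_height_and_width
size_matcher = SizeMatcher(max_image_height=height, max_image_width=width)

ds = size_matcher.transform_dataset(labels_reader.make_dataset())
for example in ds.take(2):
    # Every image now has the same height/width regardless of the source video.
    assert example["image"].shape[:2] == (height, width)
```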