Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-size videos in data pipelines #440

Merged
merged 19 commits into from
Feb 3, 2021
10 changes: 10 additions & 0 deletions sleap/nn/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ class PreprocessingConfig:
max stride (typically 32). This padding will be ignored when instance
cropping inputs since the crop size should already be divisible by the
model's max stride.
resize_and_pad_to_target: If True, will resize and pad all images in the dataset to match target dimensions.
This is useful when preprocessing datasets with mixed image dimensions (from different video resolutions).
Aspect ratio is preserved, and padding applied (if needed) to bottom or right of image only.
target_height: Target image height for 'resize_and_pad_to_target'. When not explicitly provided, inferred as the
max image height from the dataset.
target_width: Target image width for 'resize_and_pad_to_target'. When not explicitly provided, inferred as the
max image width from the dataset.
"""

ensure_rgb: bool = False
Expand All @@ -88,6 +95,9 @@ class PreprocessingConfig:
)
input_scaling: float = 1.0
pad_to_stride: Optional[int] = None
resize_and_pad_to_target: bool = True
target_height: Optional[int] = None
target_width: Optional[int] = None


@attr.s(auto_attribs=True)
Expand Down
49 changes: 43 additions & 6 deletions sleap/nn/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
RandomCropper,
)
from sleap.nn.data.normalization import Normalizer
from sleap.nn.data.resizing import Resizer, PointsRescaler
from sleap.nn.data.resizing import Resizer, PointsRescaler, SizeMatcher
from sleap.nn.data.instance_centroids import InstanceCentroidFinder
from sleap.nn.data.instance_cropping import InstanceCropper, PredictedInstanceCropper
from sleap.nn.data.confidence_maps import (
Expand Down Expand Up @@ -68,6 +68,7 @@
RandomCropper,
Normalizer,
Resizer,
SizeMatcher,
InstanceCentroidFinder,
InstanceCropper,
MultiConfidenceMapGenerator,
Expand Down Expand Up @@ -123,7 +124,7 @@ def from_blocks(
"""Create a pipeline from a sequence of providers and transformers.

Args:
sequence: List or tuple of providers and transformer instances.
blocks: List or tuple of providers and transformer instances.

Returns:
An instantiated pipeline with all blocks chained.
Expand Down Expand Up @@ -351,6 +352,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
"""
pipeline = Pipeline(providers=data_provider)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.optimization_config.augmentation_config.random_crop:
pipeline += RandomCropper(
Expand Down Expand Up @@ -391,8 +397,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
crop_width=self.optimization_config.augmentation_config.random_crop_width,
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += Resizer.from_config(self.data_config.preprocessing)

pipeline += SingleInstanceConfidenceMapGenerator(
sigma=self.single_instance_confmap_head.sigma,
output_stride=self.single_instance_confmap_head.output_stride,
Expand Down Expand Up @@ -483,6 +493,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
"""
pipeline = Pipeline(providers=data_provider)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.optimization_config.augmentation_config.random_crop:
pipeline += RandomCropper(
Expand Down Expand Up @@ -529,8 +544,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
crop_width=self.optimization_config.augmentation_config.random_crop_width,
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += Resizer.from_config(self.data_config.preprocessing)

pipeline += InstanceCentroidFinder.from_config(
self.data_config.instance_cropping,
skeletons=self.data_config.labels.skeletons,
Expand Down Expand Up @@ -637,6 +656,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
"""
pipeline = Pipeline(providers=data_provider)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += Resizer.from_config(self.data_config.preprocessing)
pipeline += InstanceCentroidFinder.from_config(
self.data_config.instance_cropping,
Expand Down Expand Up @@ -674,8 +698,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
self.optimization_config.augmentation_config
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += Resizer.from_config(self.data_config.preprocessing)

pipeline += InstanceCentroidFinder.from_config(
self.data_config.instance_cropping,
skeletons=self.data_config.labels.skeletons,
Expand Down Expand Up @@ -768,6 +796,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
"""
pipeline = Pipeline(providers=data_provider)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.optimization_config.augmentation_config.random_crop:
pipeline += RandomCropper(
Expand Down Expand Up @@ -809,8 +842,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
crop_width=aug_config.random_crop_width,
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += Resizer.from_config(self.data_config.preprocessing)

pipeline += MultiConfidenceMapGenerator(
sigma=self.confmaps_head.sigma,
output_stride=self.confmaps_head.output_stride,
Expand Down
31 changes: 25 additions & 6 deletions sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import tensorflow as tf
import attr
from typing import Text, Optional, List, Sequence, Union
from typing import Text, Optional, List, Sequence, Union, Tuple
import sleap


Expand Down Expand Up @@ -93,6 +93,19 @@ def videos(self) -> List[sleap.Video]:
"""Return the list of videos that `video_ind` in examples match up with."""
return self.labels.videos

@property
def max_height_and_width(self) -> Tuple[int, int]:
return max(video.shape[1] for video in self.videos), max(
video.shape[2] for video in self.videos
)

@property
def is_from_multi_size_videos(self) -> bool:
return (
len(set(v.shape[1] for v in self.videos)) > 1
or len(set(v.shape[2] for v in self.videos)) > 1
)

def make_dataset(
self, ds_index: Optional[tf.data.Dataset] = None
) -> tf.data.Dataset:
Expand Down Expand Up @@ -125,10 +138,10 @@ def make_dataset(
"skeleton_inds": Tensor of shape (n_instances,) of dtype tf.int32 that
specifies the index of the skeleton used for each instance.
"""
# Grab an image to test for the dtype.
test_lf = self.labels[0]
test_image = tf.convert_to_tensor(test_lf.image)
image_dtype = test_image.dtype
# Grab the first image to capture dtype and number of color channels.
first_image = tf.convert_to_tensor(self.labels[0].image)
image_dtype = first_image.dtype
image_num_channels = first_image.shape[-1]

def py_fetch_lf(ind):
"""Local function that will not be autographed."""
Expand Down Expand Up @@ -167,7 +180,13 @@ def fetch_lf(ind):
[ind],
[image_dtype, tf.int32, tf.float32, tf.int32, tf.int64, tf.int32],
)
image = tf.ensure_shape(image, test_image.shape)

# Ensure shape with constant or variable height/width, based on whether or not the videos have mixed sizes
if self.is_from_multi_size_videos:
image = tf.ensure_shape(image, (None, None, image_num_channels))
else:
image = tf.ensure_shape(image, first_image.shape)

instances = tf.ensure_shape(instances, tf.TensorShape([None, None, 2]))
skeleton_inds = tf.ensure_shape(skeleton_inds, tf.TensorShape([None]))

Expand Down
176 changes: 176 additions & 0 deletions sleap/nn/data/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,182 @@ def resize(example):
return ds_output


@attr.s(auto_attribs=True)
class SizeMatcher:
"""Data transformer that ensures output images have uniform shape by resizing/padding smaller images.

Attributes:
image_key: String name of the key containing the images to resize.
scale_key: String name of the key containing the scale of the images.
points_key: String name of the key containing points to adjust for the resizing
operation.
keep_full_image: If True, keeps the (original size) full image in the examples.
This is useful for multi-scale inference.
full_image_key: String name of the key containing the full images.
max_image_height: int The target height to which all smaller images will be resized/padded to.
max_image_width: int The target width to which all smaller images will be resized/padded to.
"""

image_key: Text = "image"
scale_key: Text = "scale"
points_key: Optional[Text] = "instances"
keep_full_image: bool = False
full_image_key: Text = "full_image"
max_image_height: int = None
max_image_width: int = None

@classmethod
def from_config(
cls,
config: PreprocessingConfig,
provider: Optional = None,
update_config: bool = True,
image_key: Text = "image",
scale_key: Text = "scale",
keep_full_image: bool = False,
full_image_key: Text = "full_image",
points_key: Optional[Text] = "instances",
) -> "SizeMatcher":
"""Build an instance of this class from configuration.

Args:
config: An `PreprocessingConfig` instance with the desired parameters. If
`config.resize_and_pad_to_target` is True and 'target_height' / 'target_width' are not set, provider
needs to be set that implements 'max_height_and_width'.
provider: Data provider.
update_config: If True, the input model configuration will be updated with
values inferred from other fields.
image_key: String name of the key containing the images to resize.
scale_key: String name of the key containing the scale of the images.
pad_to_stride: An integer specifying the `pad_to_stride` if
`config.pad_to_stride` is not an explicit integer (e.g., set to None).
keep_full_image: If True, keeps the (original size) full image in the
examples. This is useful for multi-scale inference.
full_image_key: String name of the key containing the full images.
points_key: String name of the key containing points to adjust for the
resizing operation.
Returns:
An instance of this class.

Raises:
ValueError: If `provider` is not set or does not implement `max_height_and_width`.
"""
if config.resize_and_pad_to_target:
if config.target_height is not None and config.target_width is not None:
max_height = config.target_height
max_width = config.target_width
else:
try:
max_height, max_width = provider.max_height_and_width
except:
raise ValueError(
"target_height / target_width could not be determined"
)
if update_config:
config.target_height = max_height
config.target_width = max_width
else:
max_height, max_width = None, None

return cls(
image_key=image_key,
points_key=points_key,
scale_key=scale_key,
keep_full_image=keep_full_image,
full_image_key=full_image_key,
max_image_height=max_height,
max_image_width=max_width,
)

@property
def input_keys(self) -> List[Text]:
"""Return the keys that incoming elements are expected to have."""
input_keys = [self.image_key, self.scale_key]
if self.points_key is not None:
input_keys.append(self.points_key)
return input_keys

@property
def output_keys(self) -> List[Text]:
"""Return the keys that outgoing elements will have."""
output_keys = self.input_keys
if self.keep_full_image:
output_keys.append(self.full_image_key)
return output_keys

def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
"""Transform a dataset with potentially different size images into one with equal sized images.

Args:
ds_input: A dataset with the image specified in the `image_key` attribute,
points specified in the `points_key` attribute, and the "scale" key for
tracking scaling transformations.

Returns:
A `tf.data.Dataset` with elements containing the same images and points of equal size.

If the `keep_full_image` attribute is True, a key specified by
`full_image_key` will be added with the to the example containing the image
before any processing.
"""

# mapping function: match to max height width by resizing and padding bottom/right accordingly
def resize_and_pad(example):
image = example[self.image_key]
if self.keep_full_image:
example[self.full_image_key] = image

current_shape = tf.shape(image)

# Only apply this transform if image shape differs from target
if (
current_shape[-3] != self.max_image_height
or current_shape[-2] != self.max_image_width
):
# Calculate target height and width for resizing the image (no padding yet)
hratio = self.max_image_height / tf.cast(current_shape[-3], tf.float32)
wratio = self.max_image_width / tf.cast(current_shape[-2], tf.float32)
if hratio > wratio:
# The bottleneck is width, scale to fit width first then pad to height
target_height = tf.cast(
tf.cast(current_shape[-3], tf.float32) * wratio, tf.int32
)
target_width = self.max_image_width
example[self.scale_key] = example[self.scale_key] * wratio
else:
# The bottleneck is height, scale to fit height first then pad to width
target_height = self.max_image_height
target_width = tf.cast(
tf.cast(current_shape[-2], tf.float32) * hratio, tf.int32
)
example[self.scale_key] = example[self.scale_key] * hratio
# Resize the image to fill one of the dimensions by preserving aspect ratio
image = tf.image.resize_with_pad(
image, target_height=target_height, target_width=target_width
)
# Pad the image on bottom/right with zeroes to match specified dimensions
image = tf.image.pad_to_bounding_box(
image,
offset_height=0,
offset_width=0,
target_height=self.max_image_height,
target_width=self.max_image_width,
)
example[self.image_key] = tf.cast(image, example[self.image_key].dtype)
# Scale the instance points accordingly
if self.points_key:
example[self.points_key] = (
example[self.points_key] * example[self.scale_key]
)

return example

ds_output = ds_input.map(
resize_and_pad, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
return ds_output


@attr.s(auto_attribs=True)
class PointsRescaler:
"""Transformer to apply or invert scaling operations on points."""
Expand Down
Loading