diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index eff107287..4aa05638f 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -94,7 +94,9 @@ def videos(self) -> List[sleap.Video]: return self.labels.videos def max_height_and_width(self) -> Tuple[int, int]: - return max(video.shape[1] for video in self.videos), max(video.shape[2] for video in self.videos) + return max(video.shape[1] for video in self.videos), max( + video.shape[2] for video in self.videos + ) def make_dataset( self, ds_index: Optional[tf.data.Dataset] = None diff --git a/sleap/nn/data/resizing.py b/sleap/nn/data/resizing.py index ec6efcd34..cc6217b68 100644 --- a/sleap/nn/data/resizing.py +++ b/sleap/nn/data/resizing.py @@ -275,7 +275,6 @@ class SizeMatcher: max_image_height: int = None max_image_width: int = None - @classmethod def from_config( cls, @@ -286,7 +285,7 @@ def from_config( scale_key: Text = "scale", keep_full_image: bool = False, full_image_key: Text = "full_image", - points_key: Optional[Text] = "instances" + points_key: Optional[Text] = "instances", ) -> "SizeMatcher": """Build an instance of this class from configuration. @@ -320,7 +319,9 @@ def from_config( try: max_height, max_width = provider.max_height_and_width() except: - raise ValueError("target_height / target_width could not be determined") + raise ValueError( + "target_height / target_width could not be determined" + ) if update_config: config.target_height = max_height config.target_width = max_width @@ -378,25 +379,30 @@ def resize_and_pad(example): current_shape = tf.shape(image) # Only apply this transform if image shape differs from target - if current_shape[-3] != self.max_image_height or current_shape[-2] != self.max_image_width: + if ( + current_shape[-3] != self.max_image_height + or current_shape[-2] != self.max_image_width + ): # Calculate target height and width for resizing the image (no padding yet) hratio = self.max_image_height / tf.cast(current_shape[-3], tf.float32) wratio = self.max_image_width / tf.cast(current_shape[-2], tf.float32) if hratio > wratio: # The bottleneck is width, scale to fit width first then pad to height - target_height=tf.cast(tf.cast(current_shape[-3], tf.float32) * wratio, tf.int32) - target_width=self.max_image_width + target_height = tf.cast( + tf.cast(current_shape[-3], tf.float32) * wratio, tf.int32 + ) + target_width = self.max_image_width example[self.scale_key] = example[self.scale_key] * wratio else: # The bottleneck is height, scale to fit height first then pad to width - target_height=self.max_image_height - target_width=tf.cast(tf.cast(current_shape[-2], tf.float32) * hratio, tf.int32) + target_height = self.max_image_height + target_width = tf.cast( + tf.cast(current_shape[-2], tf.float32) * hratio, tf.int32 + ) example[self.scale_key] = example[self.scale_key] * hratio # Resize the image to fill one of the dimensions by preserving aspect ratio image = tf.image.resize_with_pad( - image, - target_height=target_height, - target_width=target_width + image, target_height=target_height, target_width=target_width ) # Pad the image on bottom/right with zeroes to match specified dimensions image = tf.image.pad_to_bounding_box( @@ -404,16 +410,17 @@ def resize_and_pad(example): offset_height=0, offset_width=0, target_height=self.max_image_height, - target_width=self.max_image_width + target_width=self.max_image_width, ) example[self.image_key] = tf.cast(image, example[self.image_key].dtype) # Scale the instance points accordingly if self.points_key: - example[self.points_key] = example[self.points_key] * example[self.scale_key] + example[self.points_key] = ( + example[self.points_key] * example[self.scale_key] + ) return example - ds_output = ds_input.map( resize_and_pad, num_parallel_calls=tf.data.experimental.AUTOTUNE ) diff --git a/tests/nn/data/test_resizing.py b/tests/nn/data/test_resizing.py index 5f5b01d20..66c31ab61 100644 --- a/tests/nn/data/test_resizing.py +++ b/tests/nn/data/test_resizing.py @@ -123,6 +123,7 @@ def test_resizer_from_config(): config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=None) ) + def test_size_matcher(): # Create some fake data using two different size videos. skeleton = sleap.Skeleton.from_names_and_edge_inds(["A"])