Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
arie-matsliah committed Feb 2, 2021
1 parent fecc751 commit 60d269e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
4 changes: 3 additions & 1 deletion sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def videos(self) -> List[sleap.Video]:
return self.labels.videos

def max_height_and_width(self) -> Tuple[int, int]:
return max(video.shape[1] for video in self.videos), max(video.shape[2] for video in self.videos)
return max(video.shape[1] for video in self.videos), max(
video.shape[2] for video in self.videos
)

def make_dataset(
self, ds_index: Optional[tf.data.Dataset] = None
Expand Down
35 changes: 21 additions & 14 deletions sleap/nn/data/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ class SizeMatcher:
max_image_height: int = None
max_image_width: int = None


@classmethod
def from_config(
cls,
Expand All @@ -286,7 +285,7 @@ def from_config(
scale_key: Text = "scale",
keep_full_image: bool = False,
full_image_key: Text = "full_image",
points_key: Optional[Text] = "instances"
points_key: Optional[Text] = "instances",
) -> "SizeMatcher":
"""Build an instance of this class from configuration.
Expand Down Expand Up @@ -320,7 +319,9 @@ def from_config(
try:
max_height, max_width = provider.max_height_and_width()
except:
raise ValueError("target_height / target_width could not be determined")
raise ValueError(
"target_height / target_width could not be determined"
)
if update_config:
config.target_height = max_height
config.target_width = max_width
Expand Down Expand Up @@ -378,42 +379,48 @@ def resize_and_pad(example):
current_shape = tf.shape(image)

# Only apply this transform if image shape differs from target
if current_shape[-3] != self.max_image_height or current_shape[-2] != self.max_image_width:
if (
current_shape[-3] != self.max_image_height
or current_shape[-2] != self.max_image_width
):
# Calculate target height and width for resizing the image (no padding yet)
hratio = self.max_image_height / tf.cast(current_shape[-3], tf.float32)
wratio = self.max_image_width / tf.cast(current_shape[-2], tf.float32)
if hratio > wratio:
# The bottleneck is width, scale to fit width first then pad to height
target_height=tf.cast(tf.cast(current_shape[-3], tf.float32) * wratio, tf.int32)
target_width=self.max_image_width
target_height = tf.cast(
tf.cast(current_shape[-3], tf.float32) * wratio, tf.int32
)
target_width = self.max_image_width
example[self.scale_key] = example[self.scale_key] * wratio
else:
# The bottleneck is height, scale to fit height first then pad to width
target_height=self.max_image_height
target_width=tf.cast(tf.cast(current_shape[-2], tf.float32) * hratio, tf.int32)
target_height = self.max_image_height
target_width = tf.cast(
tf.cast(current_shape[-2], tf.float32) * hratio, tf.int32
)
example[self.scale_key] = example[self.scale_key] * hratio
# Resize the image to fill one of the dimensions by preserving aspect ratio
image = tf.image.resize_with_pad(
image,
target_height=target_height,
target_width=target_width
image, target_height=target_height, target_width=target_width
)
# Pad the image on bottom/right with zeroes to match specified dimensions
image = tf.image.pad_to_bounding_box(
image,
offset_height=0,
offset_width=0,
target_height=self.max_image_height,
target_width=self.max_image_width
target_width=self.max_image_width,
)
example[self.image_key] = tf.cast(image, example[self.image_key].dtype)
# Scale the instance points accordingly
if self.points_key:
example[self.points_key] = example[self.points_key] * example[self.scale_key]
example[self.points_key] = (
example[self.points_key] * example[self.scale_key]
)

return example


ds_output = ds_input.map(
resize_and_pad, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
Expand Down
1 change: 1 addition & 0 deletions tests/nn/data/test_resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def test_resizer_from_config():
config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=None)
)


def test_size_matcher():
# Create some fake data using two different size videos.
skeleton = sleap.Skeleton.from_names_and_edge_inds(["A"])
Expand Down

0 comments on commit 60d269e

Please sign in to comment.