Skip to content

Commit

Permalink
from config
Browse files Browse the repository at this point in the history
  • Loading branch information
arie-matsliah committed Feb 2, 2021
1 parent f161897 commit acaa740
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
2 changes: 1 addition & 1 deletion sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def videos(self) -> List[sleap.Video]:
"""Return the list of videos that `video_ind` in examples match up with."""
return self.labels.videos

def max_video_height_and_width(self) -> Tuple[int, int]:
def max_height_and_width(self) -> Tuple[int, int]:
return max(video.shape[1] for video in self.videos), max(video.shape[2] for video in self.videos)

def make_dataset(
Expand Down
62 changes: 62 additions & 0 deletions sleap/nn/data/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,68 @@ class SizeMatcher:
max_image_height: int = None
max_image_width: int = None


@classmethod
def from_config(
cls,
config: PreprocessingConfig,
provider: Optional[Provider] = None,
update_config: bool = True,
image_key: Text = "image",
scale_key: Text = "scale",
keep_full_image: bool = False,
full_image_key: Text = "full_image",
points_key: Optional[Text] = "instances"
) -> "SizeMatcher":
"""Build an instance of this class from configuration.
Args:
config: An `PreprocessingConfig` instance with the desired parameters. If
`config.resize_and_pad_to_target` is True and 'target_height' / 'target_width' are not set, provider
needs to be set that implements 'max_height_and_width'.
provider: Data provider.
update_config: If True, the input model configuration will be updated with
values inferred from other fields.
image_key: String name of the key containing the images to resize.
scale_key: String name of the key containing the scale of the images.
pad_to_stride: An integer specifying the `pad_to_stride` if
`config.pad_to_stride` is not an explicit integer (e.g., set to None).
keep_full_image: If True, keeps the (original size) full image in the
examples. This is useful for multi-scale inference.
full_image_key: String name of the key containing the full images.
points_key: String name of the key containing points to adjust for the
resizing operation.
Returns:
An instance of this class.
Raises:
ValueError: If `provider` is not set or does not implement `max_height_and_width`.
"""
if config.resize_and_pad_to_target:
if config.target_height is not None and config.target_width is not None:
max_height = config.target_height
max_width = config.target_width
else:
try:
max_height, max_width = provider.max_height_and_width()
except:
raise ValueError("target_height / target_width could not be determined")
if update_config:
config.target_height = max_height
config.target_width = max_width
else:
max_height, max_width = None, None

return cls(
image_key=image_key,
points_key=points_key,
scale_key=scale_key,
keep_full_image=keep_full_image,
full_image_key=full_image_key,
max_image_height=max_height,
max_image_width=max_width,
)

@property
def input_keys(self) -> List[Text]:
"""Return the keys that incoming elements are expected to have."""
Expand Down
2 changes: 1 addition & 1 deletion tests/nn/data/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,6 @@ def test_labels_reader_multi_size():
assert next(ds_iter)["image"].shape == (320, 560, 1)
assert next(ds_iter)["image"].shape == (512, 512, 1)

h, w = labels_reader.max_video_height_and_width()
h, w = labels_reader.max_height_and_width()
assert h == 512
assert w == 560

0 comments on commit acaa740

Please sign in to comment.