Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symmetry-aware flip augmentation #455

Merged
merged 9 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions sleap/config/training_editor_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,18 @@ augmentation:
label: Brightness Max Val
name: optimization.augmentation_config.brightness_max_val
type: double
- name: optimization.augmentation_config.random_flip
label: Random flip
help: 'Randomly reflect images and instances. IMPORTANT: Left/right symmetric nodes must be indicated in the skeleton or this will lead to incorrect results!'
type: bool
default: false
- name: optimization.augmentation_config.flip_horizontal
label: Flip left/right
help: Flip images horizontally when randomly reflecting. If unchecked, flipping will
reflect images up/down.
type: bool
default: true

optimization:
- default: 8
help: Number of examples per minibatch, i.e., a single step of training. Higher
Expand Down
93 changes: 50 additions & 43 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,8 @@ def _node_to_index(self, node: Union[str, Node]) -> int:
return self.skeleton.node_to_index(node)

def __getitem__(
self, node: Union[List[Union[str, Node]], Union[str, Node]]
) -> Union[List[Point], Point]:
self, node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray]
) -> Union[List[Point], Point, np.ndarray]:
"""
Get the Points associated with particular skeleton node(s).

Expand All @@ -552,24 +552,28 @@ def __getitem__(
to each node.

"""

# If the node is a list of nodes, use get item recursively and return a list of _points.
if type(node) is list:
ret_list = []
# If the node is a list of nodes, use get item recursively and return a list of
# _points.
if isinstance(node, (list, tuple, np.ndarray)):
pts = []
for n in node:
ret_list.append(self.__getitem__(n))
pts.append(self.__getitem__(n))

return ret_list
if isinstance(node, np.ndarray):
return np.array([[pt.x, pt.y] for pt in pts])
else:
return pts

try:
node = self._node_to_index(node)
return self._points[node]
except ValueError:
raise KeyError(
f"The underlying skeleton ({self.skeleton}) has no node '{node}'"
)
if isinstance(node, (Node, str)):
try:
node = self._node_to_index(node)
except ValueError:
raise KeyError(
f"The underlying skeleton ({self.skeleton}) has no node '{node}'"
)
return self._points[node]

def __contains__(self, node: Union[str, Node]) -> bool:
def __contains__(self, node: Union[str, Node, int]) -> bool:
"""
Whether this instance has a point with the specified node.

Expand All @@ -584,18 +588,18 @@ def __contains__(self, node: Union[str, Node]) -> bool:
if isinstance(node, Node):
node = node.name

if node not in self.skeleton:
return False

node_idx = self._node_to_index(node)
if isinstance(node, str):
if node not in self.skeleton:
return False
node = self._node_to_index(node)

# If the points are nan, then they haven't been allocated.
return not self._points[node_idx].isnan()
return not self._points[node].isnan()

def __setitem__(
self,
node: Union[List[Union[str, Node]], Union[str, Node]],
value: Union[List[Point], Point],
node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray],
value: Union[List[Point], Point, np.ndarray],
):
"""
Set the point(s) for given node(s).
Expand All @@ -612,31 +616,34 @@ def __setitem__(
Returns:
None
"""

# Make sure node and value, if either are lists, are of compatible size
if type(node) is not list and type(value) is list and len(value) != 1:
raise IndexError(
"Node list for indexing must be same length and value list."
)

if type(node) is list and type(value) is not list and len(node) != 1:
raise IndexError(
"Node list for indexing must be same length and value list."
)
if isinstance(node, (list, np.ndarray)):
if not isinstance(value, (list, np.ndarray)) or len(value) != len(node):
raise IndexError(
"Node list for indexing must be same length and value list."
)

# If we are dealing with lists, do multiple assignment recursively, this should be ok because
# skeletons and instances are small.
if type(node) is list:
for n, v in zip(node, value):
self.__setitem__(n, v)
else:
try:
node_idx = self._node_to_index(node)
self._points[node_idx] = value
except ValueError:
raise KeyError(
f"The underlying skeleton ({self.skeleton}) has no node '{node}'"
)
if isinstance(node, (Node, str)):
try:
node_idx = self._node_to_index(node)
except ValueError:
raise KeyError(
f"The skeleton ({self.skeleton}) has no node '{node}'."
)
else:
node_idx = node

if not isinstance(value, Point):
if hasattr(value, "__len__") and len(value) == 2:
value = Point(x=value[0], y=value[1])
else:
raise ValueError(
"Instance point values must be (x, y) coordinates."
)
self._points[node_idx] = value

def __delitem__(self, node: Union[str, Node]):
"""
Expand Down
7 changes: 7 additions & 0 deletions sleap/nn/config/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class AugmentationConfig:
the augmentations above.
random_crop_width: Width of random crops.
random_crop_height: Height of random crops.
random_flip: If `True`, images will be randomly reflected. The coordinates of
the instances will be adjusted accordingly. Body parts that are left/right
symmetric must be marked on the skeleton in order to be swapped correctly.
flip_horizontal: If `True`, flip images left/right when randomly reflecting
them. If `False`, flipping is done up/down instead.
"""

rotate: bool = False
Expand All @@ -78,6 +83,8 @@ class AugmentationConfig:
random_crop: bool = False
random_crop_height: int = 256
random_crop_width: int = 256
random_flip: bool = False
flip_horizontal: bool = True


@attr.s(auto_attribs=True)
Expand Down
186 changes: 185 additions & 1 deletion sleap/nn/data/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,109 @@
if hasattr(numpy.random, "_bit_generator"):
numpy.random.bit_generator = numpy.random._bit_generator

import sleap
import numpy as np
import tensorflow as tf
import attr
from typing import List, Text
from typing import List, Text, Optional
import imgaug as ia
import imgaug.augmenters as iaa
from sleap.nn.config import AugmentationConfig
from sleap.nn.data.instance_cropping import crop_bboxes


def flip_instances_lr(
    instances: tf.Tensor, img_width: int, symmetric_inds: Optional[tf.Tensor] = None
) -> tf.Tensor:
    """Mirror instance points left/right, optionally swapping symmetric nodes.

    Every x-coordinate is mapped to `(img_width - 1) - x`; y-coordinates are left
    unchanged.

    Args:
        instances: Instance points as a `tf.Tensor` of shape
            `(n_instances, n_nodes, 2)` and dtype `tf.float32`, with `(x, y)` in the
            last axis.
        img_width: Width of the image in the same units as `instances`.
        symmetric_inds: Optional `(n_symmetries, 2)` integer tensor. Each row holds
            the indices of two nodes that are mirror images of each other (e.g.,
            left/right body parts). Row order and the order within a row do not
            matter. When provided, the points of each pair are exchanged so that
            node identities stay correct after the reflection. If `None` (the
            default), no swapping is done.

    Returns:
        The instance points with x-coordinates reflected horizontally and, if
        requested, symmetric node pairs swapped.
    """
    # (offset - pts) * sign yields ((img_width - 1) - x, y): the second column is
    # negated twice, so y passes through unchanged.
    offset = tf.cast([[[img_width - 1, 0]]], tf.float32)
    sign = tf.cast([[[1, -1]]], tf.float32)
    instances = (offset - instances) * sign

    if symmetric_inds is not None:
        n_instances = tf.shape(instances)[0]
        n_symmetries = tf.shape(symmetric_inds)[0]

        # Instance index for every (instance, symmetric pair) combination, laid out
        # as [0, 0, ..., 1, 1, ...] in a column vector.
        row_inds = tf.repeat(tf.range(n_instances), n_symmetries)
        row_inds = tf.cast(tf.reshape(row_inds, [-1, 1]), tf.int32)

        # Node indices of each side of every pair, as int32 column vectors.
        first_nodes = tf.reshape(tf.gather(symmetric_inds, 0, axis=1), [-1, 1])
        first_nodes = tf.cast(first_nodes, tf.int32)
        second_nodes = tf.reshape(tf.gather(symmetric_inds, 1, axis=1), [-1, 1])
        second_nodes = tf.cast(second_nodes, tf.int32)

        # Full (instance, node) subscripts; the pair list is repeated once per
        # instance so that it lines up with `row_inds`.
        first_subs = tf.concat(
            [row_inds, tf.tile(first_nodes, [n_instances, 1])], axis=1
        )
        second_subs = tf.concat(
            [row_inds, tf.tile(second_nodes, [n_instances, 1])], axis=1
        )

        # Read both sides before writing either so the swap uses pre-swap values.
        first_pts = tf.gather_nd(instances, first_subs)
        second_pts = tf.gather_nd(instances, second_subs)
        instances = tf.tensor_scatter_nd_update(instances, first_subs, second_pts)
        instances = tf.tensor_scatter_nd_update(instances, second_subs, first_pts)

    return instances


def flip_instances_ud(
    instances: tf.Tensor, img_height: int, symmetric_inds: Optional[tf.Tensor] = None
) -> tf.Tensor:
    """Mirror instance points up/down, optionally swapping symmetric nodes.

    Every y-coordinate is mapped to `(img_height - 1) - y`; x-coordinates are left
    unchanged.

    Args:
        instances: Instance points as a `tf.Tensor` of shape
            `(n_instances, n_nodes, 2)` and dtype `tf.float32`, with `(x, y)` in the
            last axis.
        img_height: Height of the image in the same units as `instances`.
        symmetric_inds: Optional `(n_symmetries, 2)` integer tensor. Each row holds
            the indices of two nodes that are mirror images of each other (e.g.,
            left/right body parts). Row order and the order within a row do not
            matter. When provided, the points of each pair are exchanged so that
            node identities stay correct after the reflection. If `None` (the
            default), no swapping is done.

    Returns:
        The instance points with y-coordinates reflected vertically and, if
        requested, symmetric node pairs swapped.
    """
    # (offset - pts) * sign yields (x, (img_height - 1) - y): the first column is
    # negated twice, so x passes through unchanged.
    offset = tf.cast([[[0, img_height - 1]]], tf.float32)
    sign = tf.cast([[[-1, 1]]], tf.float32)
    instances = (offset - instances) * sign

    if symmetric_inds is not None:
        n_instances = tf.shape(instances)[0]
        n_symmetries = tf.shape(symmetric_inds)[0]

        # Instance index for every (instance, symmetric pair) combination, laid out
        # as [0, 0, ..., 1, 1, ...] in a column vector.
        row_inds = tf.repeat(tf.range(n_instances), n_symmetries)
        row_inds = tf.cast(tf.reshape(row_inds, [-1, 1]), tf.int32)

        # Node indices of each side of every pair, as int32 column vectors.
        first_nodes = tf.reshape(tf.gather(symmetric_inds, 0, axis=1), [-1, 1])
        first_nodes = tf.cast(first_nodes, tf.int32)
        second_nodes = tf.reshape(tf.gather(symmetric_inds, 1, axis=1), [-1, 1])
        second_nodes = tf.cast(second_nodes, tf.int32)

        # Full (instance, node) subscripts; the pair list is repeated once per
        # instance so that it lines up with `row_inds`.
        first_subs = tf.concat(
            [row_inds, tf.tile(first_nodes, [n_instances, 1])], axis=1
        )
        second_subs = tf.concat(
            [row_inds, tf.tile(second_nodes, [n_instances, 1])], axis=1
        )

        # Read both sides before writing either so the swap uses pre-swap values.
        first_pts = tf.gather_nd(instances, first_subs)
        second_pts = tf.gather_nd(instances, second_subs)
        instances = tf.tensor_scatter_nd_update(instances, first_subs, second_pts)
        instances = tf.tensor_scatter_nd_update(instances, second_subs, first_pts)

    return instances


@attr.s(auto_attribs=True)
class ImgaugAugmenter:
"""Data transformer based on the `imgaug` library.
Expand Down Expand Up @@ -249,3 +342,94 @@ def random_crop(ex):
return ex

return input_ds.map(random_crop)


@attr.s(auto_attribs=True)
class RandomFlipper:
    """Data transformer for applying random flipping to input images.

    This class can generate a `tf.data.Dataset` from an existing one that generates
    image and instance data. Elements of the output dataset will have random
    horizontal or vertical flips applied, depending on `horizontal`.

    Attributes:
        symmetric_inds: Indices of symmetric pairs of nodes as an array of shape
            `(n_symmetries, 2)`. Each row contains the indices of nodes that are
            mirror symmetric, e.g., left/right body parts. The ordering of the list
            or which node comes first (e.g., left/right vs right/left) does not
            matter. Each pair of nodes will be swapped to account for the reflection
            if this is not `None` (the default) and not empty.
        horizontal: If `True` (the default), flips are applied horizontally
            (left/right). If `False`, flips are applied vertically (up/down).
        probability: The probability that the augmentation should be applied to any
            given example.
    """

    symmetric_inds: Optional[np.ndarray] = None
    horizontal: bool = True
    probability: float = 0.5

    @classmethod
    def from_skeleton(
        cls, skeleton: sleap.Skeleton, horizontal: bool = True, probability: float = 0.5
    ) -> "RandomFlipper":
        """Create an instance of `RandomFlipper` from a skeleton.

        Args:
            skeleton: A `sleap.Skeleton` that may define symmetric nodes. Its
                `symmetric_inds` attribute is used as the symmetric pair indices.
            horizontal: If `True` (the default), flips are applied horizontally
                instead of vertically.
            probability: The probability that the augmentation should be applied.

        Returns:
            An instance of `RandomFlipper`.
        """
        return cls(
            symmetric_inds=skeleton.symmetric_inds,
            horizontal=horizontal,
            probability=probability,
        )

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have."""
        return ["image", "instances"]

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have (same as the inputs)."""
        return self.input_keys

    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        """Create a `tf.data.Dataset` with elements containing augmented data.

        Args:
            input_ds: A dataset with elements that contain the keys `"image"` and
                `"instances"`. This is typically raw data from a data provider.

        Returns:
            A `tf.data.Dataset` with the same keys as the input, but with images and
            instance points updated with the applied random flip.
        """
        # Normalize the symmetry spec once, outside the mapped function: an empty
        # list of pairs is treated the same as "no symmetries" so the flip helpers
        # skip the swap step entirely.
        symmetric_inds = self.symmetric_inds
        if symmetric_inds is not None:
            symmetric_inds = np.array(symmetric_inds)
            if len(symmetric_inds) == 0:
                symmetric_inds = None

        def random_flip(ex):
            """Apply random flip to an example."""
            # `tf.random.uniform` samples from [0, 1), so a strict `<` flips with
            # exactly `probability`: never when it is 0.0, always when it is 1.0.
            p = tf.random.uniform((), minval=0, maxval=1.0)
            if p < self.probability:
                if self.horizontal:
                    # Image axis 1 is used as the width for the x reflection.
                    img_width = tf.shape(ex["image"])[1]
                    ex["instances"] = flip_instances_lr(
                        ex["instances"], img_width, symmetric_inds=symmetric_inds
                    )
                    ex["image"] = tf.image.flip_left_right(ex["image"])
                else:
                    # Image axis 0 is used as the height for the y reflection.
                    img_height = tf.shape(ex["image"])[0]
                    ex["instances"] = flip_instances_ud(
                        ex["instances"], img_height, symmetric_inds=symmetric_inds
                    )
                    ex["image"] = tf.image.flip_up_down(ex["image"])
            return ex

        return input_ds.map(random_flip)
Loading