diff --git a/sleap/config/training_editor_form.yaml b/sleap/config/training_editor_form.yaml index 1a82280ef..44cafe008 100644 --- a/sleap/config/training_editor_form.yaml +++ b/sleap/config/training_editor_form.yaml @@ -533,6 +533,18 @@ augmentation: label: Brightness Max Val name: optimization.augmentation_config.brightness_max_val type: double +- name: optimization.augmentation_config.random_flip + label: Random flip + help: 'Randomly reflect images and instances. IMPORTANT: Left/right symmetric nodes must be indicated in the skeleton or this will lead to incorrect results!' + type: bool + default: false +- name: optimization.augmentation_config.flip_horizontal + label: Flip left/right + help: Flip images horizontally when randomly reflecting. If unchecked, flipping will + reflect images up/down. + type: bool + default: true + optimization: - default: 8 help: Number of examples per minibatch, i.e., a single step of training. Higher diff --git a/sleap/instance.py b/sleap/instance.py index 2a7750514..b32aa6d7e 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -534,8 +534,8 @@ def _node_to_index(self, node: Union[str, Node]) -> int: return self.skeleton.node_to_index(node) def __getitem__( - self, node: Union[List[Union[str, Node]], Union[str, Node]] - ) -> Union[List[Point], Point]: + self, node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray] + ) -> Union[List[Point], Point, np.ndarray]: """ Get the Points associated with particular skeleton node(s). @@ -552,24 +552,28 @@ def __getitem__( to each node. """ - - # If the node is a list of nodes, use get item recursively and return a list of _points. - if type(node) is list: - ret_list = [] + # If the node is a list of nodes, use get item recursively and return a list of + # _points. + if isinstance(node, (list, tuple, np.ndarray)): + pts = [] for n in node: - ret_list.append(self.__getitem__(n)) + pts.append(self.__getitem__(n)) - return ret_list + if isinstance(node, np.ndarray): + return np.array([[pt.x, pt.y] for pt in pts]) + else: + return pts - try: - node = self._node_to_index(node) - return self._points[node] - except ValueError: - raise KeyError( - f"The underlying skeleton ({self.skeleton}) has no node '{node}'" - ) + if isinstance(node, (Node, str)): + try: + node = self._node_to_index(node) + except ValueError: + raise KeyError( + f"The underlying skeleton ({self.skeleton}) has no node '{node}'" + ) + return self._points[node] - def __contains__(self, node: Union[str, Node]) -> bool: + def __contains__(self, node: Union[str, Node, int]) -> bool: """ Whether this instance has a point with the specified node. @@ -584,18 +588,18 @@ def __contains__(self, node: Union[str, Node]) -> bool: if isinstance(node, Node): node = node.name - if node not in self.skeleton: - return False - - node_idx = self._node_to_index(node) + if isinstance(node, str): + if node not in self.skeleton: + return False + node = self._node_to_index(node) # If the points are nan, then they haven't been allocated. - return not self._points[node_idx].isnan() + return not self._points[node].isnan() def __setitem__( self, - node: Union[List[Union[str, Node]], Union[str, Node]], - value: Union[List[Point], Point], + node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray], + value: Union[List[Point], Point, np.ndarray], ): """ Set the point(s) for given node(s). @@ -612,31 +616,34 @@ def __setitem__( Returns: None """ - # Make sure node and value, if either are lists, are of compatible size - if type(node) is not list and type(value) is list and len(value) != 1: - raise IndexError( - "Node list for indexing must be same length and value list." - ) - - if type(node) is list and type(value) is not list and len(node) != 1: - raise IndexError( - "Node list for indexing must be same length and value list." - ) + if isinstance(node, (list, np.ndarray)): + if not isinstance(value, (list, np.ndarray)) or len(value) != len(node): + raise IndexError( + "Node list for indexing must be same length and value list." + ) - # If we are dealing with lists, do multiple assignment recursively, this should be ok because - # skeletons and instances are small. - if type(node) is list: for n, v in zip(node, value): self.__setitem__(n, v) else: - try: - node_idx = self._node_to_index(node) - self._points[node_idx] = value - except ValueError: - raise KeyError( - f"The underlying skeleton ({self.skeleton}) has no node '{node}'" - ) + if isinstance(node, (Node, str)): + try: + node_idx = self._node_to_index(node) + except ValueError: + raise KeyError( + f"The skeleton ({self.skeleton}) has no node '{node}'." + ) + else: + node_idx = node + + if not isinstance(value, Point): + if hasattr(value, "__len__") and len(value) == 2: + value = Point(x=value[0], y=value[1]) + else: + raise ValueError( + "Instance point values must be (x, y) coordinates." + ) + self._points[node_idx] = value def __delitem__(self, node: Union[str, Node]): """ diff --git a/sleap/nn/config/optimization.py b/sleap/nn/config/optimization.py index 725526bba..cd43136b7 100644 --- a/sleap/nn/config/optimization.py +++ b/sleap/nn/config/optimization.py @@ -52,6 +52,11 @@ class AugmentationConfig: the augmentations above. random_crop_width: Width of random crops. random_crop_height: Height of random crops. + random_flip: If `True`, images will be randomly reflected. The coordinates of + the instances will be adjusted accordingly. Body parts that are left/right + symmetric must be marked on the skeleton in order to be swapped correctly. + flip_horizontal: If `True`, flip images left/right when randomly reflecting + them. If `False`, flipping is down up/down instead. """ rotate: bool = False @@ -78,6 +83,8 @@ class AugmentationConfig: random_crop: bool = False random_crop_height: int = 256 random_crop_width: int = 256 + random_flip: bool = False + flip_horizontal: bool = True @attr.s(auto_attribs=True) diff --git a/sleap/nn/data/augmentation.py b/sleap/nn/data/augmentation.py index 86088bfa5..186ea5db4 100644 --- a/sleap/nn/data/augmentation.py +++ b/sleap/nn/data/augmentation.py @@ -7,16 +7,109 @@ if hasattr(numpy.random, "_bit_generator"): numpy.random.bit_generator = numpy.random._bit_generator +import sleap import numpy as np import tensorflow as tf import attr -from typing import List, Text +from typing import List, Text, Optional import imgaug as ia import imgaug.augmenters as iaa from sleap.nn.config import AugmentationConfig from sleap.nn.data.instance_cropping import crop_bboxes +def flip_instances_lr( + instances: tf.Tensor, img_width: int, symmetric_inds: Optional[tf.Tensor] = None +) -> tf.Tensor: + """Flip a set of instance points horizontally with symmetric node adjustment. + + Args: + instances: Instance points as a `tf.Tensor` of shape `(n_instances, n_nodes, 2)` + and dtype `tf.float32`. + img_width: Width of image in the same units as `instances`. + symmetric_inds: Indices of symmetric pairs of nodes as a `tf.Tensor` of shape + `(n_symmetries, 2)` and dtype `tf.int32`. Each row contains the indices of + nodes that are mirror symmetric, e.g., left/right body parts. The ordering + of the list or which node comes first (e.g., left/right vs right/left) does + not matter. Each pair of nodes will be swapped to account for the + reflection if this is not `None` (the default). + + Returns: + The instance points with x-coordinates flipped horizontally. + """ + instances = (tf.cast([[[img_width - 1, 0]]], tf.float32) - instances) * tf.cast( + [[[1, -1]]], tf.float32 + ) + + if symmetric_inds is not None: + n_instances = tf.shape(instances)[0] + n_symmetries = tf.shape(symmetric_inds)[0] + + inst_inds = tf.reshape(tf.repeat(tf.range(n_instances), n_symmetries), [-1, 1]) + sym_inds1 = tf.reshape(tf.gather(symmetric_inds, 0, axis=1), [-1, 1]) + sym_inds2 = tf.reshape(tf.gather(symmetric_inds, 1, axis=1), [-1, 1]) + + inst_inds = tf.cast(inst_inds, tf.int32) + sym_inds1 = tf.cast(sym_inds1, tf.int32) + sym_inds2 = tf.cast(sym_inds2, tf.int32) + + subs1 = tf.concat([inst_inds, tf.tile(sym_inds1, [n_instances, 1])], axis=1) + subs2 = tf.concat([inst_inds, tf.tile(sym_inds2, [n_instances, 1])], axis=1) + + pts1 = tf.gather_nd(instances, subs1) + pts2 = tf.gather_nd(instances, subs2) + instances = tf.tensor_scatter_nd_update(instances, subs1, pts2) + instances = tf.tensor_scatter_nd_update(instances, subs2, pts1) + + return instances + + +def flip_instances_ud( + instances: tf.Tensor, img_height: int, symmetric_inds: Optional[tf.Tensor] = None +) -> tf.Tensor: + """Flip a set of instance points vertically with symmetric node adjustment. + + Args: + instances: Instance points as a `tf.Tensor` of shape `(n_instances, n_nodes, 2)` + and dtype `tf.float32`. + img_height: Height of image in the same units as `instances`. + symmetric_inds: Indices of symmetric pairs of nodes as a `tf.Tensor` of shape + `(n_symmetries, 2)` and dtype `tf.int32`. Each row contains the indices of + nodes that are mirror symmetric, e.g., left/right body parts. The ordering + of the list or which node comes first (e.g., left/right vs right/left) does + not matter. Each pair of nodes will be swapped to account for the + reflection if this is not `None` (the default). + + Returns: + The instance points with y-coordinates flipped horizontally. + """ + instances = (tf.cast([[[0, img_height - 1]]], tf.float32) - instances) * tf.cast( + [[[-1, 1]]], tf.float32 + ) + + if symmetric_inds is not None: + n_instances = tf.shape(instances)[0] + n_symmetries = tf.shape(symmetric_inds)[0] + + inst_inds = tf.reshape(tf.repeat(tf.range(n_instances), n_symmetries), [-1, 1]) + sym_inds1 = tf.reshape(tf.gather(symmetric_inds, 0, axis=1), [-1, 1]) + sym_inds2 = tf.reshape(tf.gather(symmetric_inds, 1, axis=1), [-1, 1]) + + inst_inds = tf.cast(inst_inds, tf.int32) + sym_inds1 = tf.cast(sym_inds1, tf.int32) + sym_inds2 = tf.cast(sym_inds2, tf.int32) + + subs1 = tf.concat([inst_inds, tf.tile(sym_inds1, [n_instances, 1])], axis=1) + subs2 = tf.concat([inst_inds, tf.tile(sym_inds2, [n_instances, 1])], axis=1) + + pts1 = tf.gather_nd(instances, subs1) + pts2 = tf.gather_nd(instances, subs2) + instances = tf.tensor_scatter_nd_update(instances, subs1, pts2) + instances = tf.tensor_scatter_nd_update(instances, subs2, pts1) + + return instances + + @attr.s(auto_attribs=True) class ImgaugAugmenter: """Data transformer based on the `imgaug` library. @@ -249,3 +342,94 @@ def random_crop(ex): return ex return input_ds.map(random_crop) + + +@attr.s(auto_attribs=True) +class RandomFlipper: + """Data transformer for applying random flipping to input images. + + This class can generate a `tf.data.Dataset` from an existing one that generates + image and instance data. Elements of the output dataset will have random horizontal + flips applied. + + Attributes: + symmetric_inds: Indices of symmetric pairs of nodes as a an array of shape + `(n_symmetries, 2)`. Each row contains the indices of nodes that are mirror + symmetric, e.g., left/right body parts. The ordering of the list or which + node comes first (e.g., left/right vs right/left) does not matter. Each pair + of nodes will be swapped to account for the reflection if this is not `None` + (the default). + horizontal: If `True` (the default), flips are applied horizontally instead of + vertically. + probability: The probability that the augmentation should be applied. + """ + + symmetric_inds: Optional[np.ndarray] = None + horizontal: bool = True + probability: float = 0.5 + + @classmethod + def from_skeleton( + cls, skeleton: sleap.Skeleton, horizontal: bool = True, probability: float = 0.5 + ) -> "RandomFlipper": + """Create an instance of `RandomFlipper` from a skeleton. + + Args: + skeleton: A `sleap.Skeleton` that may define symmetric nodes. + horizontal: If `True` (the default), flips are applied horizontally instead + of vertically. + probability: The probability that the augmentation should be applied. + + Returns: + An instance of `RandomFlipper`. + """ + return cls( + symmetric_inds=skeleton.symmetric_inds, + horizontal=horizontal, + probability=probability, + ) + + @property + def input_keys(self): + return ["image", "instances"] + + @property + def output_keys(self): + return self.input_keys + + def transform_dataset(self, input_ds: tf.data.Dataset): + """Create a `tf.data.Dataset` with elements containing augmented data. + + Args: + input_ds: A dataset with elements that contain the keys `"image"` and + `"instances"`. This is typically raw data from a data provider. + + Returns: + A `tf.data.Dataset` with the same keys as the input, but with images and + instance points updated with the applied random flip. + """ + symmetric_inds = self.symmetric_inds + if symmetric_inds is not None: + symmetric_inds = np.array(symmetric_inds) + if len(symmetric_inds) == 0: + symmetric_inds = None + + def random_flip(ex): + """Apply random flip to an example.""" + p = tf.random.uniform((), minval=0, maxval=1.0) + if p <= self.probability: + if self.horizontal: + img_width = tf.shape(ex["image"])[1] + ex["instances"] = flip_instances_lr( + ex["instances"], img_width, symmetric_inds=symmetric_inds + ) + ex["image"] = tf.image.flip_left_right(ex["image"]) + else: + img_height = tf.shape(ex["image"])[0] + ex["instances"] = flip_instances_ud( + ex["instances"], img_height, symmetric_inds=symmetric_inds + ) + ex["image"] = tf.image.flip_up_down(ex["image"]) + return ex + + return input_ds.map(random_flip) diff --git a/sleap/nn/data/pipelines.py b/sleap/nn/data/pipelines.py index eb7b51c2f..5a7544874 100644 --- a/sleap/nn/data/pipelines.py +++ b/sleap/nn/data/pipelines.py @@ -20,6 +20,7 @@ AugmentationConfig, ImgaugAugmenter, RandomCropper, + RandomFlipper, ) from sleap.nn.data.normalization import Normalizer from sleap.nn.data.resizing import Resizer, PointsRescaler @@ -93,6 +94,7 @@ KeyDeviceMover, PointsRescaler, LambdaMap, + RandomFlipper, ) Provider = TypeVar("Provider", *PROVIDERS) Transformer = TypeVar("Transformer", *TRANSFORMERS) @@ -382,6 +384,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: if self.optimization_config.online_shuffling: pipeline += Shuffler(self.optimization_config.shuffle_buffer_size) + if self.optimization_config.augmentation_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.skeletons[0], + horizontal=self.optimization_config.augmentation_config.flip_horizontal, + ) pipeline += ImgaugAugmenter.from_config( self.optimization_config.augmentation_config ) @@ -669,7 +676,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: pipeline += Shuffler( shuffle=True, buffer_size=self.optimization_config.shuffle_buffer_size ) - + if self.optimization_config.augmentation_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.skeletons[0], + horizontal=self.optimization_config.augmentation_config.flip_horizontal, + ) pipeline += ImgaugAugmenter.from_config( self.optimization_config.augmentation_config ) @@ -802,6 +813,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: ) aug_config = self.optimization_config.augmentation_config + if aug_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.skeletons[0], + horizontal=aug_config.flip_horizontal, + ) pipeline += ImgaugAugmenter.from_config(aug_config) if aug_config.random_crop: pipeline += RandomCropper( diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 1a65d79bc..478f0c80b 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -14,6 +14,7 @@ import h5py import copy +import operator from enum import Enum from itertools import count from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Text @@ -393,7 +394,7 @@ def symmetries(self) -> List[Tuple[Node, Node]]: if edge_type == EdgeType.SYMMETRY ] # Get rid of duplicates - symmetries = list(set([tuple(set(e)) for e in symmetries])) + symmetries = list(set([tuple(sorted(e, key=operator.attrgetter("name"))) for e in symmetries])) return symmetries @property @@ -413,6 +414,16 @@ def symmetries_full(self) -> List[Tuple[Node, Node, Any, Any]]: if attr["type"] == EdgeType.SYMMETRY ] + @property + def symmetric_inds(self) -> np.ndarray: + """Return the symmetric nodes as an array of indices.""" + return np.array( + [ + [self.nodes.index(node1), self.nodes.index(node2)] + for node1, node2 in self.symmetries + ] + ) + def node_to_index(self, node: NodeRef) -> int: """ Return the index of the node, accepts either `Node` or name. diff --git a/sleap/training_profiles/baseline.centroid.json b/sleap/training_profiles/baseline.centroid.json index 1393b8f7f..d46a3694f 100755 --- a/sleap/training_profiles/baseline.centroid.json +++ b/sleap/training_profiles/baseline.centroid.json @@ -72,7 +72,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_large_rf.bottomup.json b/sleap/training_profiles/baseline_large_rf.bottomup.json index f54063945..eef85ec56 100644 --- a/sleap/training_profiles/baseline_large_rf.bottomup.json +++ b/sleap/training_profiles/baseline_large_rf.bottomup.json @@ -81,7 +81,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_large_rf.single.json b/sleap/training_profiles/baseline_large_rf.single.json index 222d60558..f35f81be2 100644 --- a/sleap/training_profiles/baseline_large_rf.single.json +++ b/sleap/training_profiles/baseline_large_rf.single.json @@ -72,7 +72,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_large_rf.topdown.json b/sleap/training_profiles/baseline_large_rf.topdown.json index 9c75a09c7..1156b64a0 100644 --- a/sleap/training_profiles/baseline_large_rf.topdown.json +++ b/sleap/training_profiles/baseline_large_rf.topdown.json @@ -73,7 +73,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_medium_rf.bottomup.json b/sleap/training_profiles/baseline_medium_rf.bottomup.json index ae53a9481..c75e54e7e 100644 --- a/sleap/training_profiles/baseline_medium_rf.bottomup.json +++ b/sleap/training_profiles/baseline_medium_rf.bottomup.json @@ -81,7 +81,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_medium_rf.single.json b/sleap/training_profiles/baseline_medium_rf.single.json index b01000c0f..152fdbb9a 100644 --- a/sleap/training_profiles/baseline_medium_rf.single.json +++ b/sleap/training_profiles/baseline_medium_rf.single.json @@ -72,7 +72,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_medium_rf.topdown.json b/sleap/training_profiles/baseline_medium_rf.topdown.json index fb54f6cf3..0cdfe1cca 100755 --- a/sleap/training_profiles/baseline_medium_rf.topdown.json +++ b/sleap/training_profiles/baseline_medium_rf.topdown.json @@ -73,7 +73,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/pretrained.bottomup.json b/sleap/training_profiles/pretrained.bottomup.json index 22e2abf71..3b0e20112 100644 --- a/sleap/training_profiles/pretrained.bottomup.json +++ b/sleap/training_profiles/pretrained.bottomup.json @@ -78,7 +78,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/pretrained.centroid.json b/sleap/training_profiles/pretrained.centroid.json index 45a2d9ea5..a535688e6 100644 --- a/sleap/training_profiles/pretrained.centroid.json +++ b/sleap/training_profiles/pretrained.centroid.json @@ -69,7 +69,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/pretrained.single.json b/sleap/training_profiles/pretrained.single.json index 572b84ccd..1dfb8453f 100644 --- a/sleap/training_profiles/pretrained.single.json +++ b/sleap/training_profiles/pretrained.single.json @@ -69,7 +69,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/pretrained.topdown.json b/sleap/training_profiles/pretrained.topdown.json index c49e06a63..25cf1c54b 100644 --- a/sleap/training_profiles/pretrained.topdown.json +++ b/sleap/training_profiles/pretrained.topdown.json @@ -70,7 +70,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/tests/nn/data/test_augmentation.py b/tests/nn/data/test_augmentation.py index 68801a3d9..13751e4e4 100644 --- a/tests/nn/data/test_augmentation.py +++ b/tests/nn/data/test_augmentation.py @@ -1,5 +1,6 @@ import numpy as np import tensorflow as tf +import sleap from sleap.nn.system import use_cpu_only use_cpu_only() # hide GPUs for test @@ -51,6 +52,104 @@ def test_random_cropper(min_labels): assert "crop_bbox" in example offset = tf.stack([example["crop_bbox"][0, 1], example["crop_bbox"][0, 0]], axis=-1) assert tf.reduce_all( - example["instances"] - == (example_preaug["instances"] - tf.expand_dims(offset, axis=0)) + example["instances"] == ( + example_preaug["instances"] - tf.expand_dims(offset, axis=0))) + + +def test_flip_instances_lr(): + insts = tf.cast([ + [[0, 1], [2, 3]], + [[4, 5], [6, 7]], + ], tf.float32) + + insts_flip = augmentation.flip_instances_lr(insts, 8) + np.testing.assert_array_equal(insts_flip, [ + [[7, 1], [5, 3]], + [[3, 5], [1, 7]] + ]) + + insts_flip1 = augmentation.flip_instances_lr(insts, 8, [[0, 1]]) + insts_flip2 = augmentation.flip_instances_lr(insts, 8, [[1, 0]]) + np.testing.assert_array_equal(insts_flip1, [ + [[5, 3], [7, 1]], + [[1, 7], [3, 5]] + ]) + np.testing.assert_array_equal(insts_flip1, insts_flip2) + + +def test_flip_instances_ud(): + insts = tf.cast([ + [[0, 1], [2, 3]], + [[4, 5], [6, 7]], + ], tf.float32) + + insts_flip = augmentation.flip_instances_ud(insts, 8) + np.testing.assert_array_equal(insts_flip, [ + [[0, 6], [2, 4]], + [[4, 2], [6, 0]] + ]) + + insts_flip1 = augmentation.flip_instances_ud(insts, 8, [[0, 1]]) + insts_flip2 = augmentation.flip_instances_ud(insts, 8, [[1, 0]]) + np.testing.assert_array_equal(insts_flip1, [ + [[2, 4], [0, 6]], + [[6, 0], [4, 2]] + ]) + np.testing.assert_array_equal(insts_flip1, insts_flip2) + + +def test_random_flipper(): + vid = sleap.Video.from_filename( + "tests/data/json_format_v1/centered_pair_low_quality.mp4" + ) + skel = sleap.Skeleton.from_names_and_edge_inds(["A", "BL", "BR"], [[0, 1], [0, 2]]) + labels = sleap.Labels([sleap.LabeledFrame(video=vid, frame_idx=0, instances=[ + sleap.Instance.from_pointsarray([[25, 50], [50, 25], [25, 25]], skeleton=skel), + sleap.Instance.from_pointsarray([[125, 150], [150, 125], [125, 125]], skeleton=skel), + ])]) + + p = labels.to_pipeline() + p += sleap.nn.data.augmentation.RandomFlipper.from_skeleton( + skel, horizontal=True, probability=1.) + ex = p.peek() + np.testing.assert_array_equal(ex["image"], vid[0][0][:, ::-1]) + np.testing.assert_array_equal( + ex["instances"], + [[[358., 50.], [333., 25.], [358., 25.]], + [[258., 150.], [233., 125.], [258., 125.]]] + ) + + skel.add_symmetry("BL", "BR") + + p = labels.to_pipeline() + p += sleap.nn.data.augmentation.RandomFlipper.from_skeleton( + skel, horizontal=True, probability=1.) + ex = p.peek() + np.testing.assert_array_equal(ex["image"], vid[0][0][:, ::-1]) + np.testing.assert_array_equal( + ex["instances"], + [[[358., 50.], [358., 25.], [333., 25.]], + [[258., 150.], [258., 125.], [233., 125.]]] + ) + + p = labels.to_pipeline() + p += sleap.nn.data.augmentation.RandomFlipper.from_skeleton( + skel, horizontal=True, probability=0.) + ex = p.peek() + np.testing.assert_array_equal(ex["image"], vid[0][0]) + np.testing.assert_array_equal( + ex["instances"], + [[[25, 50], [50, 25], [25, 25]], + [[125, 150], [150, 125], [125, 125]]] + ) + + p = labels.to_pipeline() + p += sleap.nn.data.augmentation.RandomFlipper.from_skeleton( + skel, horizontal=False, probability=1.) + ex = p.peek() + np.testing.assert_array_equal(ex["image"], vid[0][0][::-1, :]) + np.testing.assert_array_equal( + ex["instances"], + [[[25, 333], [25, 358], [50, 358]], + [[125, 233], [125, 258], [150, 258]]] ) diff --git a/tests/test_instance.py b/tests/test_instance.py index 922bde08e..0d0992cde 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -34,12 +34,20 @@ def test_instance_node_get_set_item(skeleton): thorax_point = instance["thorax"] assert math.isnan(thorax_point.x) and math.isnan(thorax_point.y) + instance[0] = [-20, -50] + assert instance["head"].x == -20 + assert instance["head"].y == -50 + + instance[0] = np.array([-21, -51]) + assert instance["head"].x == -21 + assert instance["head"].y == -51 + def test_instance_node_multi_get_set_item(skeleton): """ Test basic get item and set item functionality of instances. """ - node_names = ["left-wing", "head", "right-wing"] + node_names = ["head", "left-wing", "right-wing"] points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3, 6)} instance1 = Instance(skeleton=skeleton, points=points) @@ -51,6 +59,24 @@ def test_instance_node_multi_get_set_item(skeleton): assert np.allclose(x_values, [1, 2, 3]) assert np.allclose(y_values, [4, 5, 6]) + + np.testing.assert_array_equal( + instance1[np.array([0, 2, 4])], + [[1, 4], [np.nan, np.nan], [3, 6]] + ) + + instance1[np.array([0, 1])] = [[1, 2], [3, 4]] + np.testing.assert_array_equal(instance1[np.array([0, 1])], [[1, 2], [3, 4]]) + + instance1[[0, 1]] = [[4, 3], [2, 1]] + np.testing.assert_array_equal(instance1[np.array([0, 1])], [[4, 3], [2, 1]]) + + instance1[["left-wing", "right-wing"]] = [[-4, -3], [-2, -1]] + np.testing.assert_array_equal(instance1[np.array([3, 4])], [[-4, -3], [-2, -1]]) + assert instance1["left-wing"].x == -4 + assert instance1["left-wing"].y == -3 + assert instance1["right-wing"].x == -2 + assert instance1["right-wing"].y == -1 def test_non_exist_node(skeleton): diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index acdda25b1..ad3667a94 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -144,6 +144,14 @@ def test_symmetry(): s1.add_symmetry("1", "5") s1.add_symmetry("3", "6") + assert (s1.nodes[0], s1.nodes[4]) in s1.symmetries + assert (s1.nodes[2], s1.nodes[5]) in s1.symmetries + assert len(s1.symmetries) == 2 + + assert (0, 4) in s1.symmetric_inds + assert (2, 5) in s1.symmetric_inds + assert len(s1.symmetric_inds) == 2 + assert s1.get_symmetry("1").name == "5" assert s1.get_symmetry("5").name == "1"