From 1e7f41e6e8f13e150ac6957d628e9d4ddc29b7c0 Mon Sep 17 00:00:00 2001 From: Martin <1500595+bmmtstb@users.noreply.github.com> Date: Wed, 15 May 2024 13:04:34 +0200 Subject: [PATCH] Add tests and catches for edge cases of empty detections - added EMPTY_STATE in state.py and tested it Signed-off-by: Martin <1500595+bmmtstb@users.noreply.github.com> --- configs/DGS/eval_sim_indep.yaml | 2 +- dgs/models/dataset/keypoint_rcnn.py | 14 ++++-- dgs/utils/state.py | 11 +++++ tests/models/dataset/test__pt21.py | 27 ++++++----- tests/utils/state/test__collate_state.py | 60 ++++++++++++++++++------ 5 files changed, 84 insertions(+), 30 deletions(-) diff --git a/configs/DGS/eval_sim_indep.yaml b/configs/DGS/eval_sim_indep.yaml index 5dcc505..1f6b700 100644 --- a/configs/DGS/eval_sim_indep.yaml +++ b/configs/DGS/eval_sim_indep.yaml @@ -20,7 +20,7 @@ dl_rcnn: module_name: "KeypointRCNNImageBackbone" dataset_path: "./data/" base_path: "./data/PoseTrack21/images/val/" - batch_size: 32 + batch_size: 16 threshold: 0.75 return_lists: true crop_size: !!python/tuple [256, 192] diff --git a/dgs/models/dataset/keypoint_rcnn.py b/dgs/models/dataset/keypoint_rcnn.py index 2626978..5bc9fb2 100644 --- a/dgs/models/dataset/keypoint_rcnn.py +++ b/dgs/models/dataset/keypoint_rcnn.py @@ -22,7 +22,7 @@ from dgs.utils.constants import IMAGE_FORMATS, VIDEO_FORMATS from dgs.utils.files import is_dir, is_file from dgs.utils.image import CustomToAspect, load_image -from dgs.utils.state import State +from dgs.utils.state import EMPTY_STATE, State from dgs.utils.types import Config, FilePath, FilePaths, Image, Images, NodePath, Validations from dgs.utils.utils import extract_crops_from_images @@ -74,15 +74,23 @@ def images_to_states(self, images: Images) -> list[State]: With the filepath given in the state, the image can be reloaded if required. """ - outputs = self.model(images) + # predict list of {boxes: XYWH[N], labels: Int64[N], scores: [N], keypoints: Float[N,J,(x|y|vis)]} + # every image in images can have multiple predictions + outputs: list[dict[str, torch.Tensor]] = self.model(images) states: list[State] = [] canvas_size = (max(i.shape[-2] for i in images), max(i.shape[-1] for i in images)) for output, image in zip(outputs, images): - # for every image (output), get the indices where the score is bigger than the threshold + # get the output for every image independently + # get the indices where the score ('certainty') is bigger than the given threshold indices = output["scores"] > self.threshold + # skip if there aren't any detections + if not torch.any(indices): + states.append(EMPTY_STATE.copy()) + continue + # bbox given in XYXY format bbox = tvte.BoundingBoxes(output["boxes"][indices], format="XYXY", canvas_size=canvas_size) # keypoints in [x,y,v] format -> kp, vis diff --git a/dgs/utils/state.py b/dgs/utils/state.py index 0d55341..83ab5d0 100644 --- a/dgs/utils/state.py +++ b/dgs/utils/state.py @@ -725,6 +725,11 @@ def clean(self, keys: Union[list[str], str] = None) -> "State": return self +EMPTY_STATE = State( + bbox=tv_tensors.BoundingBoxes(torch.empty((0, 4)), canvas_size=(0, 0), format="XYXY"), validate=False +) + + def get_ds_data_getter(attributes: list[str]) -> DataGetter: """Given a list of attribute names, return a function, that gets those attributes from a given :class:`State`. @@ -807,6 +812,12 @@ def collate_states(batch: Union[list[State], State]) -> State: if isinstance(batch, State): return batch + # remove all empty states and return early for length 0 or 1 + batch = [b for b in batch if b.B != 0] + + if len(batch) == 0: + return EMPTY_STATE.copy() + if len(batch) == 1: return batch[0] diff --git a/tests/models/dataset/test__pt21.py b/tests/models/dataset/test__pt21.py index 146d2c2..0fbd2af 100644 --- a/tests/models/dataset/test__pt21.py +++ b/tests/models/dataset/test__pt21.py @@ -14,7 +14,7 @@ from dgs.models.loader import get_data_loader from dgs.utils.config import load_config from dgs.utils.files import is_abs_dir, mkdir_if_missing -from dgs.utils.state import State +from dgs.utils.state import EMPTY_STATE, State from dgs.utils.utils import HidePrint from tests.utils.state import * @@ -232,16 +232,21 @@ def test_dataloader(self): self.assertTrue(isinstance(batch, State)) self.assertEqual(len(batch), B) - if B != 0: - # check the number of dimensions - self.assertEqual(batch.class_id.ndim, 1) - self.assertTrue(all(img.ndim == 4 for img in batch.image)) - self.assertEqual(batch.image_crop.ndim, 4) - self.assertEqual(batch.joint_weight.ndim, 3) - self.assertEqual(batch.keypoints.ndim, 3) - self.assertEqual(batch.keypoints_local.ndim, 3) - self.assertEqual(batch.person_id.ndim, 1) - self.assertEqual(batch.class_id.ndim, 1) + if B == 0: + self.assertEqual(batch, EMPTY_STATE) + for k in ["image_crop", "joint_weights", "keypoints", "keypoints_local", "person_id", "class_id"]: + self.assertTrue(k not in batch) + continue + + # check the number of dimensions + self.assertEqual(batch.class_id.ndim, 1) + self.assertTrue(all(img.ndim == 4 for img in batch.image)) + self.assertEqual(batch.image_crop.ndim, 4) + self.assertEqual(batch.joint_weight.ndim, 3) + self.assertEqual(batch.keypoints.ndim, 3) + self.assertEqual(batch.keypoints_local.ndim, 3) + self.assertEqual(batch.person_id.ndim, 1) + self.assertEqual(batch.class_id.ndim, 1) # check that the first dimension is B self.assertEqual(batch.class_id.size(0), B) diff --git a/tests/utils/state/test__collate_state.py b/tests/utils/state/test__collate_state.py index 905dbc7..effcd34 100644 --- a/tests/utils/state/test__collate_state.py +++ b/tests/utils/state/test__collate_state.py @@ -3,7 +3,15 @@ import torch from torchvision.tv_tensors import BoundingBoxes, Image, TVTensor -from dgs.utils.state import collate_bboxes, collate_devices, collate_states, collate_tensors, collate_tvt_tensors, State +from dgs.utils.state import ( + collate_bboxes, + collate_devices, + collate_states, + collate_tensors, + collate_tvt_tensors, + EMPTY_STATE, + State, +) N = 10 J = 17 @@ -130,29 +138,29 @@ def test_tensors(self): class TestCollateStates(unittest.TestCase): + bbox = BoundingBoxes(torch.ones(4), format="XYWH", canvas_size=(100, 100)) - def test_states(self): + def test_collate_states(self): for validate in [True, False]: - bbox = BoundingBoxes(torch.ones(4), format="XYWH", canvas_size=(100, 100)) - s = State( - bbox=bbox, keypoints=torch.ones(1, J, j_dim), image=[Image(torch.ones(1, C, H, W))], validate=validate + bbox=self.bbox, + keypoints=torch.ones(1, J, j_dim), + image=[Image(torch.ones(1, C, H, W))], + validate=validate, + ) + n_states = State( + bbox=BoundingBoxes(torch.ones((N, 4)), format="XYWH", canvas_size=(100, 100)), + keypoints=torch.ones((N, J, j_dim)), + image=[Image(torch.ones(1, C, H, W)) for _ in range(N)], + validate=validate, ) for states, result in [ ([s], s), (s, s), + ([s for _ in range(N)], n_states), ( - [s for _ in range(N)], - State( - bbox=BoundingBoxes(torch.ones((N, 4)), format="XYWH", canvas_size=(100, 100)), - keypoints=torch.ones((N, J, j_dim)), - image=[Image(torch.ones(1, C, H, W)) for _ in range(N)], - validate=validate, - ), - ), - ( - [State(bbox=bbox, str="dummy", tuple=(1,), validate=validate) for _ in range(N)], + [State(bbox=self.bbox, str="dummy", tuple=(1,), validate=validate) for _ in range(N)], State( bbox=BoundingBoxes(torch.ones((N, 4)), format="XYWH", canvas_size=(100, 100)), str=tuple("dummy" for _ in range(N)), @@ -165,6 +173,28 @@ def test_states(self): self.assertTrue(collate_states(states) == result) self.assertEqual(result.validate, validate) + def test_empty_states(self): + s = State( + bbox=self.bbox, keypoints=torch.ones(1, J, j_dim), image=[Image(torch.ones(1, C, H, W))], validate=False + ) + n_states = State( + bbox=BoundingBoxes(torch.ones((N, 4)), format="XYWH", canvas_size=(100, 100)), + keypoints=torch.ones((N, J, j_dim)), + image=[Image(torch.ones(1, C, H, W)) for _ in range(N)], + validate=False, + ) + + for states, result in [ + ([], EMPTY_STATE), + ([EMPTY_STATE], EMPTY_STATE), + (EMPTY_STATE, EMPTY_STATE), + ([EMPTY_STATE, EMPTY_STATE], EMPTY_STATE), + ([EMPTY_STATE, s], s), + ([v for sublist in [[s, EMPTY_STATE] for _ in range(N)] for v in sublist], n_states), # [s,ES,s,ES,...] + ]: + with self.subTest(msg="s: {}, res: {}".format(len(states), result.B)): + self.assertTrue(collate_states(states) == result) + if __name__ == "__main__": unittest.main()