Add tests and catches for edge cases of empty detections
- added EMPTY_STATE in state.py and tested it

Signed-off-by: Martin <[email protected]>
bmmtstb committed May 15, 2024
1 parent 109f121 commit 1e7f41e
Showing 5 changed files with 84 additions and 30 deletions.
2 changes: 1 addition & 1 deletion configs/DGS/eval_sim_indep.yaml
@@ -20,7 +20,7 @@ dl_rcnn:
module_name: "KeypointRCNNImageBackbone"
dataset_path: "./data/"
base_path: "./data/PoseTrack21/images/val/"
batch_size: 32
batch_size: 16
threshold: 0.75
return_lists: true
crop_size: !!python/tuple [256, 192]
14 changes: 11 additions & 3 deletions dgs/models/dataset/keypoint_rcnn.py
@@ -22,7 +22,7 @@
from dgs.utils.constants import IMAGE_FORMATS, VIDEO_FORMATS
from dgs.utils.files import is_dir, is_file
from dgs.utils.image import CustomToAspect, load_image
from dgs.utils.state import State
from dgs.utils.state import EMPTY_STATE, State
from dgs.utils.types import Config, FilePath, FilePaths, Image, Images, NodePath, Validations
from dgs.utils.utils import extract_crops_from_images

@@ -74,15 +74,23 @@ def images_to_states(self, images: Images) -> list[State]:
With the filepath given in the state, the image can be reloaded if required.
"""

outputs = self.model(images)
# predict list of {boxes: XYWH[N], labels: Int64[N], scores: [N], keypoints: Float[N,J,(x|y|vis)]}
# every image in images can have multiple predictions
outputs: list[dict[str, torch.Tensor]] = self.model(images)

states: list[State] = []
canvas_size = (max(i.shape[-2] for i in images), max(i.shape[-1] for i in images))

for output, image in zip(outputs, images):
# for every image (output), get the indices where the score is bigger than the threshold
# get the output for every image independently
# get the indices where the score ('certainty') is bigger than the given threshold
indices = output["scores"] > self.threshold

# skip if there aren't any detections
if not torch.any(indices):
states.append(EMPTY_STATE.copy())
continue

# bbox given in XYXY format
bbox = tvte.BoundingBoxes(output["boxes"][indices], format="XYXY", canvas_size=canvas_size)
# keypoints in [x,y,v] format -> kp, vis
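For illustration, a minimal, self-contained sketch of the new guard: a toy prediction dict stands in for the real Keypoint R-CNN output (the tensor values and the threshold are made up; only the key names follow torchvision's detection output format), and the print calls stand in for appending to the list of states.

import torch

# one image's predictions; keys mirror torchvision's Keypoint R-CNN output format
output = {
    "boxes": torch.zeros((2, 4)),          # XYXY boxes, one row per detection
    "scores": torch.tensor([0.42, 0.61]),  # per-detection confidence
    "keypoints": torch.zeros((2, 17, 3)),  # (x, y, visibility) for each of 17 joints
}
threshold = 0.75

# boolean mask of the detections whose score clears the threshold
indices = output["scores"] > threshold

if not torch.any(indices):
    # no confident detection for this image -> the dataset appends EMPTY_STATE.copy() instead
    print("no confident detections, falling back to the empty placeholder state")
else:
    print(f"keeping {int(indices.sum())} of {len(output['scores'])} detections")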
11 changes: 11 additions & 0 deletions dgs/utils/state.py
@@ -725,6 +725,11 @@ def clean(self, keys: Union[list[str], str] = None) -> "State":
return self


EMPTY_STATE = State(
bbox=tv_tensors.BoundingBoxes(torch.empty((0, 4)), canvas_size=(0, 0), format="XYXY"), validate=False
)


def get_ds_data_getter(attributes: list[str]) -> DataGetter:
"""Given a list of attribute names,
return a function, that gets those attributes from a given :class:`State`.
@@ -807,6 +812,12 @@ def collate_states(batch: Union[list[State], State]) -> State:
if isinstance(batch, State):
return batch

# remove all empty states and return early for length 0 or 1
batch = [b for b in batch if b.B != 0]

if len(batch) == 0:
return EMPTY_STATE.copy()

if len(batch) == 1:
return batch[0]

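As a quick usage sketch (assuming the dgs package is importable), the new collation behaviour boils down to the following; the expectations mirror the test cases added further below:

from dgs.utils.state import EMPTY_STATE, collate_states

# an empty batch, or a batch made up only of empty states, collapses to the placeholder
assert collate_states([]) == EMPTY_STATE
assert collate_states([EMPTY_STATE, EMPTY_STATE]) == EMPTY_STATE

# EMPTY_STATE has batch size zero, so the filter above drops it before collation;
# mixing it with real states therefore leaves only the real states behind (see the tests below)
assert EMPTY_STATE.B == 0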
27 changes: 16 additions & 11 deletions tests/models/dataset/test__pt21.py
@@ -14,7 +14,7 @@
from dgs.models.loader import get_data_loader
from dgs.utils.config import load_config
from dgs.utils.files import is_abs_dir, mkdir_if_missing
from dgs.utils.state import State
from dgs.utils.state import EMPTY_STATE, State
from dgs.utils.utils import HidePrint
from tests.utils.state import *

@@ -232,16 +232,21 @@ def test_dataloader(self):
self.assertTrue(isinstance(batch, State))
self.assertEqual(len(batch), B)

if B != 0:
# check the number of dimensions
self.assertEqual(batch.class_id.ndim, 1)
self.assertTrue(all(img.ndim == 4 for img in batch.image))
self.assertEqual(batch.image_crop.ndim, 4)
self.assertEqual(batch.joint_weight.ndim, 3)
self.assertEqual(batch.keypoints.ndim, 3)
self.assertEqual(batch.keypoints_local.ndim, 3)
self.assertEqual(batch.person_id.ndim, 1)
self.assertEqual(batch.class_id.ndim, 1)
if B == 0:
self.assertEqual(batch, EMPTY_STATE)
for k in ["image_crop", "joint_weights", "keypoints", "keypoints_local", "person_id", "class_id"]:
self.assertTrue(k not in batch)
continue

# check the number of dimensions
self.assertEqual(batch.class_id.ndim, 1)
self.assertTrue(all(img.ndim == 4 for img in batch.image))
self.assertEqual(batch.image_crop.ndim, 4)
self.assertEqual(batch.joint_weight.ndim, 3)
self.assertEqual(batch.keypoints.ndim, 3)
self.assertEqual(batch.keypoints_local.ndim, 3)
self.assertEqual(batch.person_id.ndim, 1)
self.assertEqual(batch.class_id.ndim, 1)

# check that the first dimension is B
self.assertEqual(batch.class_id.size(0), B)
60 changes: 45 additions & 15 deletions tests/utils/state/test__collate_state.py
@@ -3,7 +3,15 @@
import torch
from torchvision.tv_tensors import BoundingBoxes, Image, TVTensor

from dgs.utils.state import collate_bboxes, collate_devices, collate_states, collate_tensors, collate_tvt_tensors, State
from dgs.utils.state import (
collate_bboxes,
collate_devices,
collate_states,
collate_tensors,
collate_tvt_tensors,
EMPTY_STATE,
State,
)

N = 10
J = 17
@@ -130,29 +138,29 @@ def test_tensors(self):


class TestCollateStates(unittest.TestCase):
bbox = BoundingBoxes(torch.ones(4), format="XYWH", canvas_size=(100, 100))

def test_states(self):
def test_collate_states(self):
for validate in [True, False]:
bbox = BoundingBoxes(torch.ones(4), format="XYWH", canvas_size=(100, 100))

s = State(
bbox=bbox, keypoints=torch.ones(1, J, j_dim), image=[Image(torch.ones(1, C, H, W))], validate=validate
bbox=self.bbox,
keypoints=torch.ones(1, J, j_dim),
image=[Image(torch.ones(1, C, H, W))],
validate=validate,
)
n_states = State(
bbox=BoundingBoxes(torch.ones((N, 4)), format="XYWH", canvas_size=(100, 100)),
keypoints=torch.ones((N, J, j_dim)),
image=[Image(torch.ones(1, C, H, W)) for _ in range(N)],
validate=validate,
)

for states, result in [
([s], s),
(s, s),
([s for _ in range(N)], n_states),
(
[s for _ in range(N)],
State(
bbox=BoundingBoxes(torch.ones((N, 4)), format="XYWH", canvas_size=(100, 100)),
keypoints=torch.ones((N, J, j_dim)),
image=[Image(torch.ones(1, C, H, W)) for _ in range(N)],
validate=validate,
),
),
(
[State(bbox=bbox, str="dummy", tuple=(1,), validate=validate) for _ in range(N)],
[State(bbox=self.bbox, str="dummy", tuple=(1,), validate=validate) for _ in range(N)],
State(
bbox=BoundingBoxes(torch.ones((N, 4)), format="XYWH", canvas_size=(100, 100)),
str=tuple("dummy" for _ in range(N)),
@@ -165,6 +173,28 @@ def test_states(self):
self.assertTrue(collate_states(states) == result)
self.assertEqual(result.validate, validate)

def test_empty_states(self):
s = State(
bbox=self.bbox, keypoints=torch.ones(1, J, j_dim), image=[Image(torch.ones(1, C, H, W))], validate=False
)
n_states = State(
bbox=BoundingBoxes(torch.ones((N, 4)), format="XYWH", canvas_size=(100, 100)),
keypoints=torch.ones((N, J, j_dim)),
image=[Image(torch.ones(1, C, H, W)) for _ in range(N)],
validate=False,
)

for states, result in [
([], EMPTY_STATE),
([EMPTY_STATE], EMPTY_STATE),
(EMPTY_STATE, EMPTY_STATE),
([EMPTY_STATE, EMPTY_STATE], EMPTY_STATE),
([EMPTY_STATE, s], s),
([v for sublist in [[s, EMPTY_STATE] for _ in range(N)] for v in sublist], n_states), # [s,ES,s,ES,...]
]:
with self.subTest(msg="s: {}, res: {}".format(len(states), result.B)):
self.assertTrue(collate_states(states) == result)


if __name__ == "__main__":
unittest.main()
