Skip to content

Commit

Permalink
Add track ids if present for ultralytics models (#4569)
Browse files Browse the repository at this point in the history
* add track ids for ultralytics detections if present

* add track ids for other tasks

---------

Co-authored-by: Rusty Nail <[email protected]>
  • Loading branch information
Rusteam and Rusty Nail authored Jul 12, 2024
1 parent ff903a1 commit 61afb2d
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions fiftyone/utils/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def convert_ultralytics_model(model):
)


def _extract_track_ids(result):
"""Get ultralytics track ids if present, else use Nones"""
return (
result.boxes.id.detach().cpu().numpy().astype(int)
if result.boxes.is_track
else [None] * len(result.boxes.conf.size(0))
)


def to_detections(results, confidence_thresh=None):
"""Converts ``ultralytics.YOLO`` boxes to FiftyOne format.
Expand Down Expand Up @@ -84,9 +93,10 @@ def _to_detections(result, confidence_thresh=None):
classes = np.rint(result.boxes.cls.detach().cpu().numpy()).astype(int)
boxes = result.boxes.xywhn.detach().cpu().numpy().astype(float)
confs = result.boxes.conf.detach().cpu().numpy().astype(float)
track_ids = _extract_track_ids(result)

detections = []
for cls, box, conf in zip(classes, boxes, confs):
for cls, box, conf, idx in zip(classes, boxes, confs, track_ids):
if confidence_thresh is not None and conf < confidence_thresh:
continue

Expand All @@ -97,6 +107,7 @@ def _to_detections(result, confidence_thresh=None):
label=label,
bounding_box=[xc - 0.5 * w, yc - 0.5 * h, w, h],
confidence=conf,
index=idx,
)
detections.append(detection)

Expand Down Expand Up @@ -135,13 +146,16 @@ def _to_instances(result, confidence_thresh=None):
boxes = result.boxes.xywhn.detach().cpu().numpy().astype(float)
masks = result.masks.data.detach().cpu().numpy() > 0.5
confs = result.boxes.conf.detach().cpu().numpy().astype(float)
track_ids = _extract_track_ids(result)

# convert from center coords to corner coords
boxes[:, 0] -= boxes[:, 2] / 2.0
boxes[:, 1] -= boxes[:, 3] / 2.0

detections = []
for cls, box, mask, conf in zip(classes, boxes, masks, confs):
for cls, box, mask, conf, idx in zip(
classes, boxes, masks, confs, track_ids
):
if confidence_thresh is not None and conf < confidence_thresh:
continue

Expand All @@ -163,6 +177,7 @@ def _to_instances(result, confidence_thresh=None):
bounding_box=list(box),
mask=sub_mask.astype(bool),
confidence=conf,
index=idx,
)
detections.append(detection)

Expand Down Expand Up @@ -258,6 +273,7 @@ def _to_polylines(result, tolerance, filled, confidence_thresh=None):

classes = np.rint(result.boxes.cls.detach().cpu().numpy()).astype(int)
confs = result.boxes.conf.detach().cpu().numpy().astype(float)
track_ids = _extract_track_ids(result)

if tolerance > 1:
masks = result.masks.data.detach().cpu().numpy() > 0.5
Expand All @@ -267,7 +283,9 @@ def _to_polylines(result, tolerance, filled, confidence_thresh=None):
points = result.masks.xyn

polylines = []
for cls, mask, _points, conf in zip(classes, masks, points, confs):
for cls, mask, _points, conf, idx in zip(
classes, masks, points, confs, track_ids
):
if confidence_thresh is not None and conf < confidence_thresh:
continue

Expand All @@ -284,6 +302,7 @@ def _to_polylines(result, tolerance, filled, confidence_thresh=None):
confidence=conf,
closed=True,
filled=filled,
index=idx,
)
polylines.append(polyline)

Expand Down Expand Up @@ -324,9 +343,10 @@ def _to_keypoints(result, confidence_thresh=None):
confs = result.keypoints.conf.detach().cpu().numpy().astype(float)
else:
confs = itertools.repeat(None)
track_ids = _extract_track_ids(result)

keypoints = []
for cls, _points, _confs in zip(classes, points, confs):
for cls, _points, _confs, idx in zip(classes, points, confs, track_ids):
if confidence_thresh is not None:
_points[_confs < confidence_thresh] = np.nan

Expand All @@ -337,6 +357,7 @@ def _to_keypoints(result, confidence_thresh=None):
label=label,
points=_points.tolist(),
confidence=_confidence,
index=idx,
)
keypoints.append(keypoint)

Expand Down

0 comments on commit 61afb2d

Please sign in to comment.