Skip to content

Commit

Permalink
add track ids for other tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Rusteam committed Jul 12, 2024
1 parent fa78da0 commit d3731be
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions fiftyone/utils/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def convert_ultralytics_model(model):
)


def _extract_track_ids(result):
"""Get ultralytics track ids if present, else use Nones"""
return (
result.boxes.id.detach().cpu().numpy().astype(int)
if result.boxes.is_track
else [None] * len(result.boxes.conf.size(0))
)


def to_detections(results, confidence_thresh=None):
"""Converts ``ultralytics.YOLO`` boxes to FiftyOne format.
Expand Down Expand Up @@ -84,11 +93,7 @@ def _to_detections(result, confidence_thresh=None):
classes = np.rint(result.boxes.cls.detach().cpu().numpy()).astype(int)
boxes = result.boxes.xywhn.detach().cpu().numpy().astype(float)
confs = result.boxes.conf.detach().cpu().numpy().astype(float)
track_ids = (
result.boxes.id.detach().cpu().numpy().astype(int)
if result.boxes.is_track
else [None] * len(boxes)
)
track_ids = _extract_track_ids(result)

detections = []
for cls, box, conf, idx in zip(classes, boxes, confs, track_ids):
Expand Down Expand Up @@ -141,13 +146,16 @@ def _to_instances(result, confidence_thresh=None):
boxes = result.boxes.xywhn.detach().cpu().numpy().astype(float)
masks = result.masks.data.detach().cpu().numpy() > 0.5
confs = result.boxes.conf.detach().cpu().numpy().astype(float)
track_ids = _extract_track_ids(result)

# convert from center coords to corner coords
boxes[:, 0] -= boxes[:, 2] / 2.0
boxes[:, 1] -= boxes[:, 3] / 2.0

detections = []
for cls, box, mask, conf in zip(classes, boxes, masks, confs):
for cls, box, mask, conf, idx in zip(
classes, boxes, masks, confs, track_ids
):
if confidence_thresh is not None and conf < confidence_thresh:
continue

Expand All @@ -169,6 +177,7 @@ def _to_instances(result, confidence_thresh=None):
bounding_box=list(box),
mask=sub_mask.astype(bool),
confidence=conf,
index=idx,
)
detections.append(detection)

Expand Down Expand Up @@ -264,6 +273,7 @@ def _to_polylines(result, tolerance, filled, confidence_thresh=None):

classes = np.rint(result.boxes.cls.detach().cpu().numpy()).astype(int)
confs = result.boxes.conf.detach().cpu().numpy().astype(float)
track_ids = _extract_track_ids(result)

if tolerance > 1:
masks = result.masks.data.detach().cpu().numpy() > 0.5
Expand All @@ -273,7 +283,9 @@ def _to_polylines(result, tolerance, filled, confidence_thresh=None):
points = result.masks.xyn

polylines = []
for cls, mask, _points, conf in zip(classes, masks, points, confs):
for cls, mask, _points, conf, idx in zip(
classes, masks, points, confs, track_ids
):
if confidence_thresh is not None and conf < confidence_thresh:
continue

Expand All @@ -290,6 +302,7 @@ def _to_polylines(result, tolerance, filled, confidence_thresh=None):
confidence=conf,
closed=True,
filled=filled,
index=idx,
)
polylines.append(polyline)

Expand Down Expand Up @@ -330,9 +343,10 @@ def _to_keypoints(result, confidence_thresh=None):
confs = result.keypoints.conf.detach().cpu().numpy().astype(float)
else:
confs = itertools.repeat(None)
track_ids = _extract_track_ids(result)

keypoints = []
for cls, _points, _confs in zip(classes, points, confs):
for cls, _points, _confs, idx in zip(classes, points, confs, track_ids):
if confidence_thresh is not None:
_points[_confs < confidence_thresh] = np.nan

Expand All @@ -343,6 +357,7 @@ def _to_keypoints(result, confidence_thresh=None):
label=label,
points=_points.tolist(),
confidence=_confidence,
index=idx,
)
keypoints.append(keypoint)

Expand Down

0 comments on commit d3731be

Please sign in to comment.