Skip to content

Commit

Permalink
add type checking for CLASSES (open-mmlab#1239)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tau-J authored Mar 14, 2022
1 parent 663cfc8 commit afb740f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
3 changes: 3 additions & 0 deletions demo/webcam_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ def process_mmdet_results(mmdet_results, class_names=None, cat_ids=1):
if isinstance(mmdet_results, tuple):
mmdet_results = mmdet_results[0]

if isinstance(class_names, str):
class_names = (class_names, )

if not isinstance(cat_ids, (list, tuple)):
cat_ids = [cat_ids]

Expand Down
10 changes: 7 additions & 3 deletions tools/webcam/webcam_apis/nodes/mmdet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@ def _post_process(self, preds):
dets = preds
segms = [None] * len(dets)

assert len(dets) == len(self.model.CLASSES)
assert len(segms) == len(self.model.CLASSES)
det_model_classes = self.model.CLASSES
if isinstance(det_model_classes, str):
det_model_classes = (det_model_classes, )

assert len(dets) == len(det_model_classes)
assert len(segms) == len(det_model_classes)
result = {'preds': [], 'model_cfg': self.model.cfg.copy()}

for i, (cls_name, bboxes,
masks) in enumerate(zip(self.model.CLASSES, dets, segms)):
masks) in enumerate(zip(det_model_classes, dets, segms)):
if masks is None:
masks = [None] * len(bboxes)
else:
Expand Down

0 comments on commit afb740f

Please sign in to comment.