diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py index bff300121d..e3801a38d3 100644 --- a/demo/webcam_demo.py +++ b/demo/webcam_demo.py @@ -168,6 +168,9 @@ def process_mmdet_results(mmdet_results, class_names=None, cat_ids=1): if isinstance(mmdet_results, tuple): mmdet_results = mmdet_results[0] + if isinstance(class_names, str): + class_names = (class_names, ) + if not isinstance(cat_ids, (list, tuple)): cat_ids = [cat_ids] diff --git a/tools/webcam/webcam_apis/nodes/mmdet_node.py b/tools/webcam/webcam_apis/nodes/mmdet_node.py index 4207647c92..2ff079b3c5 100644 --- a/tools/webcam/webcam_apis/nodes/mmdet_node.py +++ b/tools/webcam/webcam_apis/nodes/mmdet_node.py @@ -62,12 +62,16 @@ def _post_process(self, preds): dets = preds segms = [None] * len(dets) - assert len(dets) == len(self.model.CLASSES) - assert len(segms) == len(self.model.CLASSES) + det_model_classes = self.model.CLASSES + if isinstance(det_model_classes, str): + det_model_classes = (det_model_classes, ) + + assert len(dets) == len(det_model_classes) + assert len(segms) == len(det_model_classes) result = {'preds': [], 'model_cfg': self.model.cfg.copy()} for i, (cls_name, bboxes, - masks) in enumerate(zip(self.model.CLASSES, dets, segms)): + masks) in enumerate(zip(det_model_classes, dets, segms)): if masks is None: masks = [None] * len(bboxes) else: