diff --git a/utils/segment/plots.py b/utils/segment/plots.py index 6303103ed084..4517ff455cba 100644 --- a/utils/segment/plots.py +++ b/utils/segment/plots.py @@ -93,8 +93,8 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg' if paths: annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames if len(targets) > 0: - j = targets[:, 0] == i - ti = targets[j] # image targets + idx = targets[:, 0] == i + ti = targets[idx] # image targets boxes = xywh2xyxy(ti[:, 2:6]).T classes = ti[:, 1].astype('int') @@ -126,13 +126,13 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg' image_masks = np.repeat(image_masks, nl, axis=0) image_masks = np.where(image_masks == index, 1.0, 0.0) else: - image_masks = masks[j] + image_masks = masks[idx] im = np.asarray(annotator.im).copy() for j, box in enumerate(boxes.T.tolist()): if labels or conf[j] > 0.25: # 0.25 conf thresh color = colors(classes[j]) - mh, mw = image_masks[j].shape[:2] + mh, mw = image_masks[j].shape if mh != h or mw != w: mask = image_masks[j].astype(np.uint8) mask = cv2.resize(mask, (w, h))