RuntimeError: The shape of the mask [230202] at index 0 does not match the shape of the indexed tensor [0] at index 0 #31
Hi, This error usually happens when the learning rate is too high: the gradients explode and the activations become inf, so that once we try to perform indexing, it goes out of range. That would be my first guess, but let me look into it in a bit more detail to see if it could be something else. |
I actually think that this error might come from a different place. I have the impression that one of your images doesn't have a bounding box target. Is that right? |
If that's the case, I'd recommend removing those images during training. |
I have been looking into the mmdetection implementation, and they raise an error if there are no ground-truth boxes in the image, see here. |
Hi, Did you have the chance to verify if your dataset contains images without annotations? Let me know what you think. |
Hi, |
I had a similar error that was caused by a segmentation that was started for an image but never finished. A dirty workaround is something like:
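A minimal sketch of that kind of workaround, assuming COCO-style polygon annotations; the helper name, the filtering condition, and the 6-coordinate minimum are illustrative assumptions, not the original snippet:

```python
# Illustrative sketch, not the original workaround: drop annotations whose
# segmentation was started but never finished. Assumes polygon-format
# segmentations (iscrowd == 0); a finished polygon needs at least 3 points,
# i.e. 6 coordinates.
def drop_unfinished_segmentations(anno):
    return [
        obj for obj in anno
        if obj.get("segmentation")
        and all(len(poly) >= 6 for poly in obj["segmentation"])
    ]
```
|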
When I change the code as you mentioned in #37, the error gets caught by "ValueError: No ground-truth boxes available for one of the images". Sorry for that; I think it may be because I checked my dataset in Pascal VOC format, but the code converting from Pascal VOC to COCO had some problem. Is it possible to print the filename of the image when the error is caught? |
I think it is possible to do so, but we would need to do it in the dataset class itself, somewhere near here: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/datasets/coco.py#L40

```python
if not boxes:
    raise ValueError("Image id {} ({}) doesn't have annotations!".format(self.ids[idx], anno))
```

Let me know what happens. I'm looking into improving the error message for those corner cases. |
@fmassa Hi |
By default, we train with images of min size 800 and max size 1333; you can modify this in the config. About the OOM errors: do you have a very large number of instances in a single image in your dataset (> 500)? Also, where do you get the OOM error? I think it might be related to the box IoU calculation. About the error not being caught earlier, that's weird; could you paste the part that you modified in the COCODataset?
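For reference, a minimal sketch of lowering those sizes through the config; the key names come from `maskrcnn_benchmark/config/defaults.py`, and the 600/1000 values are only illustrative:

```python
from maskrcnn_benchmark.config import cfg

# Sketch: reduce the train-time image sizes, e.g. to ease memory pressure.
# The 600/1000 values are illustrative, not a recommendation.
cfg.INPUT.MIN_SIZE_TRAIN = 600
cfg.INPUT.MAX_SIZE_TRAIN = 1000
```
|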
Yeah. And here is what I modified in COCODataset:
|
I have modified the initial message that I sent; check it again. There was a bug in the initial version that I posted, and now it should show the image that is problematic. The new version is edited into this thread on GitHub. |
I got the same issue too. When I checked the data, I found that there was an empty box.
And here is what I modified at https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/engine/trainer.py#L57:

```python
for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
    data_time = time.time() - end
    arguments["iteration"] = iteration

    # ignore batches whose image has no ground-truth boxes
    if len(targets[0]) < 1:
        print('num_boxes: ', len(targets[0]))
        continue

    scheduler.step()
```
|
@zylhub thanks for the snippet! |
@fmassa yes, I have written my own dataset class, and I am feeding in COCO format. |
My assumption in this case was that the code I wrote in the COCODataset would already catch those cases. I still need to understand why this isn't the case, so that I can make the code more robust. Skipping the batch is definitely a solution, but I'd rather avoid it for now. |
@fmassa I think the check at `__init__` is not enough to prevent that error. We also need a check after calling `clip_to_image` in `__getitem__` to further guard against empty bounding boxes.
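As a rough sketch of such a guard (the helper name and error message here are hypothetical; only `clip_to_image(remove_empty=True)` is the repo's API):

```python
def clip_and_check(target, img_id):
    # Hypothetical guard for __getitem__: even if an image had annotations
    # at __init__ time, clipping can leave it with zero valid boxes.
    target = target.clip_to_image(remove_empty=True)
    if len(target) == 0:
        raise ValueError("Image id {} has no valid boxes after clip_to_image".format(img_id))
    return target
```
|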
@xuanyuzhou98 oh, that's a very good point! |
@fmassa Hi, I found the reason that may cause this strange phenomenon! |
@liminghao1630 thanks for reaching back. Could you explain a bit more what was happening before, and maybe how we could make the asserts more robust? |
I get a similar problem. However, I converted my dataset to MS COCO format, and I have both bboxes and segmentations, but not all images have segmentations, in which case I only use the bboxes. As a first try, I decided to run Faster R-CNN and set MODEL.MASK_ON = False in the config.
Should that be sufficient? In any case, running Faster R-CNN shouldn't need the segmentations. As I think my problem is similar, I will investigate it a bit more to see where it crashes. Edit: I found that a couple of images have bounding boxes with negative coordinates, due to some faulty files during generation. Is this an error that can be caught?
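One way to catch such faulty files up front is a sanity pass over the raw COCO annotations before training. A sketch, assuming pycocotools and xywh-format boxes; the helper name and the exact checks are mine, not part of the repo:

```python
def find_bad_bboxes(coco, img_ids):
    # Sketch: flag annotations whose xywh boxes have negative coordinates,
    # non-positive sizes, or extend past the image boundaries.
    # `coco` is a pycocotools.coco.COCO instance.
    bad = []
    for img_id in img_ids:
        img = coco.loadImgs(img_id)[0]
        w, h = img["width"], img["height"]
        for obj in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
            x, y, bw, bh = obj["bbox"]
            if x < 0 or y < 0 or bw <= 0 or bh <= 0 or x + bw > w or y + bh > h:
                bad.append((img_id, obj["id"], obj["bbox"]))
    return bad
```
|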
This was happening to me because some of my annotations were out of image bounds. Maybe it is worth having a check against this. |
@miguelvr were they completely out of the image, or just part of it? |
just part of it, sometimes it was even hard to notice with the drawn bboxes |
Interesting. I'm not yet sure where/why this would crash in the current code; I'd need to check a bit more. |
I have also met this problem, and I'm sure all the images in my dataset have bboxes. I'm trying to remove the out-of-image-boundary bbox annotations to find out whether that works. |
@txytju let us know if that was what was necessary to fix your problem |
@fmassa In my case, the problem was caused by some bounding boxes with degenerate sizes (width or height no larger than a pixel)[1]. My workaround was to add an extra check in the `__init__` of the COCODataset:

```python
# filter images without detection annotations
if remove_images_without_annotations:
    total_removed = 0
    len_ids_before = len(self.ids)
    self.ids = [
        img_id
        for img_id in self.ids
        if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
    ]
    total_removed += len_ids_before - len(self.ids)
    if total_removed > 0:
        print("{} images were removed because they do not have annotations!".format(total_removed))

    ids_to_remove = []
    for img_id in self.ids:
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anno = self.coco.loadAnns(ann_ids)
        anno = [obj for obj in anno if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        img_data = self.coco.loadImgs(img_id)[0]
        img_size = (img_data['width'], img_data['height'])  # avoid loading the image itself
        target = BoxList(boxes, img_size, mode="xywh").convert("xyxy")
        target = target.clip_to_image(remove_empty=True)
        if len(target) == 0:
            ids_to_remove.append(img_id)

    total_removed += len(ids_to_remove)
    if len(ids_to_remove) > 0:
        self.ids = [img_id for img_id in self.ids if img_id not in ids_to_remove]
        print("These {} images were removed after `clip_to_image`: {}".format(
            len(ids_to_remove), ', '.join(map(str, ids_to_remove))
        ))
    if total_removed > 0:
        print('In total, {} images without annotations were removed'.format(total_removed))
```

If you prefer the current exception-oriented approach, one exception could be thrown here:

```python
if remove_empty:
    box = self.bbox
    keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
    if keep.sum() == 0:
        raise ValueError("No valid ground-truth boxes available for one of the images")
    return self[keep]
```

I can submit a PR if you tell me which one is the preferred solution (if any). I personally prefer the first (maybe you know a better/faster way to check this). In case you need it:
[1] There are 29 of these cases in the COCO2014 train dataset itself, but there are other bboxes in these images. The error happened to me because I was using a subset of the classes, which led these other bboxes to be discarded. |
Wow, that's some pretty nice investigation! If the underlying reason is that we have annotations with boxes of (almost) zero width or height, then filtering those images out makes sense. But I think there is a slightly simpler way of implementing it: why not just check whether

```python
if all(any(o < 1 for o in obj["bbox"][2:]) for obj in anno if obj["iscrowd"] == 0):
    to_remove.append(img_id)
```

applies for each image? If you could send such a PR it would be awesome! |
@fmassa I've just submitted the PR with small modifications: I changed your suggestion to less than or equal to, removed the prints (I can bring them back using the logger, but it did not feel right, since the logger is not used in the class), and formatted the code using black. |
I've merged #396 |
@fmassa I think so. If anyone else catches another missing case, it can be handled on-demand in other issues. BTW, thanks for the amazing work you've been doing. 👏 |
Closing per @rodrigoberriel's comment. If you still face similar issues, please just let us know. |
❓ Questions and Help
```
Traceback (most recent call last):
  File "github/maskrcnn-benchmark/tools/train_net.py", line 170, in <module>
    main()
  File "github/maskrcnn-benchmark/tools/train_net.py", line 163, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "github/maskrcnn-benchmark/tools/train_net.py", line 73, in train
    arguments,
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/engine/trainer.py", line 65, in do_train
    loss_dict = model(images, targets)
  File "/home/v-minghl/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py", line 50, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/home/v-minghl/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 94, in forward
    return self._forward_train(anchors, objectness, rpn_box_regression, targets)
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 113, in _forward_train
    anchors, objectness, rpn_box_regression, targets
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/loss.py", line 91, in __call__
    labels, regression_targets = self.prepare_targets(anchors, targets)
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/loss.py", line 62, in prepare_targets
    labels_per_image[~anchors_per_image.get_field("visibility")] = -1
RuntimeError: The shape of the mask [230202] at index 0 does not match the shape of the indexed tensor [0] at index 0
```
Hi,
I used my own dataset in COCO format to train, with batch size 2 and max iter 15000, training on a single P100.
It runs successfully for 2540 iterations and then crashes with this error.
I succeeded in training this dataset with mmdetection.
I have no idea why this error happened; please give me some help.