
RuntimeError: The shape of the mask [230202] at index 0 does not match the shape of the indexed tensor [0] at index 0 #31

Closed
liminghao1630 opened this issue Oct 26, 2018 · 34 comments


@liminghao1630

❓ Questions and Help

Traceback (most recent call last):
  File "github/maskrcnn-benchmark/tools/train_net.py", line 170, in <module>
    main()
  File "github/maskrcnn-benchmark/tools/train_net.py", line 163, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "github/maskrcnn-benchmark/tools/train_net.py", line 73, in train
    arguments,
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/engine/trainer.py", line 65, in do_train
    loss_dict = model(images, targets)
  File "/home/v-minghl/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py", line 50, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/home/v-minghl/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 94, in forward
    return self._forward_train(anchors, objectness, rpn_box_regression, targets)
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 113, in _forward_train
    anchors, objectness, rpn_box_regression, targets
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/loss.py", line 91, in __call__
    labels, regression_targets = self.prepare_targets(anchors, targets)
  File "/home/v-minghl/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/loss.py", line 62, in prepare_targets
    labels_per_image[~anchors_per_image.get_field("visibility")] = -1
RuntimeError: The shape of the mask [230202] at index 0 does not match the shape of the indexed tensor [0] at index 0

Hi,
I trained on my own dataset in COCO format, with batch size 2 and max iter 15000, on a single P100.
It ran successfully for 2540 iterations and then crashed with this error.
I was able to train on this dataset with mmdetection.
I have no idea why the error happened; please give me some help.

@fmassa
Contributor

fmassa commented Oct 26, 2018

Hi,

This error usually happens when the learning rate is too high: the gradients explode, the activations become inf, and once we try to perform indexing it goes out of range.

That would be my first guess, but let me look in a bit more detail to see whether it could be something else.
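In the meantime, one quick way to try a lower learning rate is a yacs-style command-line override (a sketch; adjust the config path to your own):

    python tools/train_net.py --config-file configs/your_config.yaml SOLVER.BASE_LR 0.0025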

@fmassa
Contributor

fmassa commented Oct 26, 2018

I actually think that this error might come from a different place.

I have the impression that one of your images doesn't have a bounding box target. Is that right?

@fmassa
Contributor

fmassa commented Oct 26, 2018

If that's the case, I'd recommend removing those images during training.
That's what we do in https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/datasets/coco.py#L19-L24, and this filtering is enabled by default during training.
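For reference, the filter there is roughly (paraphrased from the linked lines):

    if remove_images_without_annotations:
        self.ids = [
            img_id
            for img_id in self.ids
            if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
        ]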

@fmassa
Contributor

fmassa commented Oct 26, 2018

I have been looking into the mmdetection implementation, and they raise an error if there are no ground-truth boxes in the image, see here.
Could you verify whether there is indeed at least one image without ground-truth boxes?
I should probably add an error as well, instead of silently returning an empty tensor.

@fmassa
Contributor

fmassa commented Oct 26, 2018

Hi,

Did you have the chance to verify if your dataset contains images without annotations?
I've just submitted a PR that should help catch those cases, see #37.

Let me know what you think.

@liminghao1630
Author

Hi,
Sorry for the slow reply.
I have already checked the dataset and made sure there is no image without annotations in it.
I was using 0.01 as the base_lr and changed it to 0.0025 for now, but it still crashed, at iteration 1960.

@mveres01

I had a similar error that was caused by a segmentation that was started for an image but never finished (e.g. ann['segmentation'] = []). This doesn't get caught by #31 (comment)

A dirty workaround is something like:

if remove_images_without_annotations:
    ...

    # keep only images that have at least one non-empty segmentation
    out = []
    for img_id in self.ids:
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        for ann in self.coco.loadAnns(ann_ids):
            if len(ann['segmentation']) > 0:
                out.append(img_id)
                break
    self.ids = out

@liminghao1630
Author

When I changed the code as you mentioned in #37, the error got caught by "ValueError: No ground-truth boxes available for one of the images". Sorry for that; I think it may be because I checked my dataset in Pascal VOC format, but the code converting from Pascal VOC to COCO format had some problem. Would it be possible to print the filename of the image when the error is caught?

@fmassa
Contributor

fmassa commented Oct 26, 2018

I think it is possible to do so, but we would need to do it in the dataset class itself, somewhere near here: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/datasets/coco.py#L40
Something like

if not boxes:
    raise ValueError("Image id {} ({}) doesn't have annotations!".format(self.ids[idx], anno))

Let me know what happens. I'm looking into improving the error message for those corner cases.

@liminghao1630
Author

@fmassa Hi
I have changed the code, but the error was still only caught as "ValueError: No ground-truth boxes available for one of the images".
By the way, when I use mmdetection or the tensorflow OD API (Faster R-CNN), I can train with batch size 8; with maskrcnn-benchmark (Mask R-CNN) I can only use batch size 2. Anything bigger, for example batch size 4, crashes with "RuntimeError: CUDA error: out of memory".
I noticed the config file has no information about the resizer, so I guess maybe it is due to this?

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d"
  BACKBONE:
    CONV_BODY: "R-101-FPN"
    OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  ROI_MASK_HEAD:
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
    PREDICTOR: "MaskRCNNC4Predictor"
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 2
    RESOLUTION: 28
    SHARE_BOX_FEATURE_EXTRACTOR: False
  RESNETS:
    STRIDE_IN_1X1: False
    NUM_GROUPS: 32
    WIDTH_PER_GROUP: 8
  MASK_ON: True
DATASETS:
  TRAIN: ("latex_5000", "latex_5000")
  TEST: ("latex_5000",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.0025
  WEIGHT_DECAY: 0.0001
  STEPS: (10000, 12000)
  MAX_ITER: 15000
  IMS_PER_BATCH: 2

@fmassa
Contributor

fmassa commented Oct 27, 2018

We train by default with images of min size 800 and max size 1333. You can modify this via INPUT.MIN_SIZE_TRAIN. Is that what you meant?
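For example (a sketch of the relevant YAML overrides; the 512 values are just illustrative, mirroring your TF setup):

INPUT:
  MIN_SIZE_TRAIN: 512  # shorter side is resized to this (default 800)
  MAX_SIZE_TRAIN: 512  # longer side is capped at this (default 1333)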

About the OOM errors: do you have a very large number of instances in a single image in your dataset (> 500)? Also, where do you get the OOM error? I think it might be related to the box IoU computation.

About the error not being caught earlier: that's weird. Could you paste the part of the COCODataset that you modified, showing the extra lines together with the surrounding context?

@liminghao1630
Author

liminghao1630 commented Oct 27, 2018

Yeah.
Maybe that is the reason for the out-of-memory error.
In the tensorflow OD API, I use 512 as the max size and pad to the max size.
The OOM error appeared after these logs:

loading annotations into memory...
Done (t=0.05s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
2018-10-27 03:05:16,302 maskrcnn_benchmark.trainer INFO: Start training

And here is what I modified in COCODataset:

    def __getitem__(self, idx):
        img, anno = super(COCODataset, self).__getitem__(idx)
        # filter crowd annotations
        # TODO might be better to add an extra field
        anno = [obj for obj in anno if obj["iscrowd"] == 0]
        if not anno:
            raise ValueError("Image id {} ({}) doesn't have annotations!".format(self.ids[idx], anno))
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")

@fmassa
Contributor

fmassa commented Oct 27, 2018

I had modified the initial message that I sent; check it again. There was a bug in the initial version that I posted, and the new version (edited in GitHub in this thread) should now show which image is problematic.

@zylhub

zylhub commented Oct 30, 2018

I got the same issue too. When I checked the data, I found that there was an empty box.

So I ignore those samples during training; until now, everything has been fine.

And here is what I modified: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/engine/trainer.py#L57

for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
    data_time = time.time() - end
    arguments["iteration"] = iteration
    # skip batches whose targets contain no boxes
    if len(targets[0]) < 1:
        print('num_boxes: ', len(targets[0]))
        continue
    scheduler.step()

@fmassa
Contributor

fmassa commented Oct 30, 2018

@zylhub thanks for the snippet!
My original idea was to have the dataset remove images with empty boxes when in training mode, but it seems that my solution doesn't cover all cases.
Just to know, did you write your own dataset class, or are you using the COCODataset, and feeding your data in COCO format?

@zylhub

zylhub commented Oct 31, 2018

@fmassa yes, I wrote my own dataset class, and I feed my data in COCO format.

@fmassa
Contributor

fmassa commented Oct 31, 2018

My assumption in this case was that the code I wrote in the __init__ of the COCODataset class would remove images with empty boxes.

I still need to understand why this isn't the case, so that I can make the code more robust. Skipping the batch is definitely a solution, but I'd rather avoid it for now.

@xuanyuzhou98

@fmassa I think the check in __init__ is not enough to prevent that error. We also need a check after calling clip_to_image in __getitem__ to further guard against empty bounding boxes.
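Something like this right after the clipping in __getitem__ (a sketch):

    target = target.clip_to_image(remove_empty=True)
    if len(target) == 0:
        raise ValueError("Image id {} has no valid boxes left after clip_to_image".format(self.ids[idx]))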

@fmassa
Contributor

fmassa commented Nov 1, 2018

@xuanyuzhou98 oh, that's a very good point!
I'll add a check in the dataset loading as well, thanks!

@liminghao1630
Author

@fmassa Hi, I found what may have caused the strange behavior!
I used object detection data without segmentation annotations directly in Mask R-CNN.

@fmassa
Contributor

fmassa commented Nov 2, 2018

@liminghao1630 thanks for reporting back. Could you explain a bit more what was happening before, and maybe how we could make the asserts more robust?

@jonasteuwen

jonasteuwen commented Nov 2, 2018

I get a similar problem. However, I converted my dataset to MS COCO format, and I have both bboxes and segmentations, but not all images have segmentations, in which case I only use bboxes. As a first try, I decided to run Faster R-CNN (MODEL.MASK_ON = False) and, in that case, disabled these lines in coco.py:

        masks = [obj["segmentation"] for obj in anno]
        masks = SegmentationMask(masks, img.size)
        target.add_field("masks", masks)

Should that be sufficient? In any case, running FasterRCNN shouldn't need the segmentations.
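One way to keep the masks optional instead of deleting those lines (a sketch; the self.use_masks flag is hypothetical, e.g. set from cfg.MODEL.MASK_ON when the dataset is constructed):

        if self.use_masks:  # hypothetical flag mirroring cfg.MODEL.MASK_ON
            masks = [obj["segmentation"] for obj in anno]
            masks = SegmentationMask(masks, img.size)
            target.add_field("masks", masks)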

As I think my problem is similar, I will investigate it a bit more, to see where it crashes.

Edit: I found that a couple of images have bounding boxes with negative coordinates, caused by some faulty files during generation. Is this an error that can be caught?
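A quick offline check for such boxes could look like this (a sketch; it assumes COCO-style [x, y, w, h] bboxes and pycocotools):

    from pycocotools.coco import COCO

    coco = COCO("annotations.json")  # hypothetical path to the annotation file
    for ann in coco.loadAnns(coco.getAnnIds()):
        x, y, w, h = ann["bbox"]
        if x < 0 or y < 0 or w <= 0 or h <= 0:
            print("Suspicious bbox {} in image {}".format(ann["bbox"], ann["image_id"]))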

@miguelvr
Contributor

This was happening to me because some of my annotations were out of the image bounds. Maybe it is worth adding a check against this.

@fmassa
Contributor

fmassa commented Dec 14, 2018

@miguelvr were they completely out of the image, or just part of it?

@miguelvr
Contributor

Just part of it; sometimes it was even hard to notice with the drawn bboxes.

@fmassa
Contributor

fmassa commented Dec 14, 2018

Interesting. I'm not yet sure where/why this would crash in the current code; I'd need to check a bit more.

@txytju

txytju commented Dec 16, 2018

I have also hit this problem, and I'm sure all the images in my dataset have bboxes. I'm trying to remove out-of-image-boundary bbox annotations to find out whether that works.

@fmassa
Contributor

fmassa commented Dec 17, 2018

@txytju let us know if that turns out to fix your problem

@rodrigoberriel
Contributor

@fmassa In my case, the problem was caused by some bounding boxes with width or height in [0, 1] (COCO dataset¹). These are only removed after .clip_to_image(...). The problem arises at bounding_box.py#L85-L86, where ({w,h} - TO_REMOVE).clamp(min=0) sets xmin == xmax and/or ymin == ymax when the width and/or height of the bbox is less than or equal to TO_REMOVE = 1. The problem is that, later, clip_to_image(...) checks (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]), which becomes False in these cases.

My workaround was to add an extra check inside the "if remove_images_without_annotations" block, anticipating these cases. This is what I did (I added some prints because I always want to be aware of this fact):

# filter images without detection annotations
if remove_images_without_annotations:
    total_removed = 0
    len_ids_before = len(self.ids)
    self.ids = [
        img_id
        for img_id in self.ids
        if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
    ]
    total_removed += len_ids_before - len(self.ids)
    if total_removed > 0:
        print("{} images were removed because they do not have annotations!".format(total_removed))

    ids_to_remove = []
    for img_id in self.ids:
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anno = self.coco.loadAnns(ann_ids)
        anno = [obj for obj in anno if obj["iscrowd"] == 0]
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        img_data = self.coco.loadImgs(img_id)[0]
        img_size = (img_data['width'], img_data['height'])  # try to speed up by avoid loading the image
        target = BoxList(boxes, img_size, mode="xywh").convert("xyxy")
        target = target.clip_to_image(remove_empty=True)
        if len(target) == 0:
            ids_to_remove.append(img_id)

    total_removed += len(ids_to_remove)
    if len(ids_to_remove) > 0:
        self.ids = [img_id for img_id in self.ids if img_id not in ids_to_remove]
        print("These {} images were removed after `clip_to_image`: {}".format(
            len(ids_to_remove), ', '.join(list(map(str, ids_to_remove)))
        ))

    if total_removed > 0:
        print('In total, {} images without annotations were removed'.format(total_removed))

If you prefer the current exception-oriented approach, one exception could be thrown here:

if remove_empty:
    box = self.bbox
    keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
    if keep.sum() == 0:
        raise ValueError("No valid ground-truth boxes available for one of the images")
    return self[keep]

I can submit a PR if you tell me which one is the preferred solution (if any). I personally prefer the first (maybe you know a better/faster way to check this).

In case you need:

  • Python 3.5
  • torch.__version__: 1.0.0.dev20190125
  • git log -n 1: 5f2a826

¹ There are 29 of these cases in the COCO 2014 train dataset itself, but those images contain other bboxes as well. The error happened to me because I was using a subset of the classes, which caused those other bboxes to be discarded.

@fmassa
Contributor

fmassa commented Jan 31, 2019

Hi @rodrigoberriel

Wow, that's some pretty nice investigation!

If the underlying reason is that we have annotations with width or height < 1 (which is probably an annotation error, btw), then I think I'd prefer the first approach, where we remove those images beforehand.

But I think there is a slightly simpler way of implementing it: why not just loadAnns, and then remove the image if a condition like

if all(any(o < 1 for o in obj["bbox"][2:]) for obj in anno if obj["iscrowd"] == 0):
    # i.e. every non-crowd annotation has a width or height below 1 pixel
    to_remove.append(img_id)

applies for image img_id?

If you could send such a PR it would be awesome!

@rodrigoberriel
Contributor

@fmassa I've just submitted the PR with small modifications: I changed your suggestion to "less than or equal to", removed the prints (I can bring them back using the logger, but it did not feel right -- the logger is not used in the class), and formatted the code using black.

@fmassa
Contributor

fmassa commented Jan 31, 2019

I've merged #396
@rodrigoberriel do you think we could close this issue?

@rodrigoberriel
Contributor

@fmassa I think so. If anyone else catches another missing case, it can be handled on demand in a separate issue.

BTW, thanks for the amazing work you've been doing. 👏

@fmassa
Contributor

fmassa commented Jan 31, 2019

Closing per @rodrigoberriel's comment. If you still face similar issues, please let us know.
