This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

How to train with negative samples of my training set? #169

Open
Igal20 opened this issue Nov 18, 2018 · 47 comments
Labels
question Further information is requested

Comments

@Igal20

Igal20 commented Nov 18, 2018

❓ Questions and Help

Hello,
My data has only two classes: background = 0 and object = 1.
During training it is necessary for me to present negative samples to the net, i.e. images without an object, just background. In that case I don't have a segmentation contour. How do I add those images to the training?

I use COCO-style annotations for the images, saved in json format.
Thanks in advance.

@Igal20 Igal20 changed the title How to add negative samples for my training set? How to train with negative samples of my training set? Nov 18, 2018
@wjp0408

wjp0408 commented Nov 19, 2018

@Igal20
Author

Igal20 commented Nov 19, 2018

@wjp0408 Not quite: issue #144 handles negative samples simply by filtering them out and not including them in training at all.
I don't want to exclude them; I want to use them during training.

@fmassa
Contributor

fmassa commented Nov 19, 2018

Hi

One image should contain two things: a positive element and a negative element.
This is always the case for detection, unless your bounding boxes cover the full image, in which case you might be looking for image classification rather than object detection?

@fmassa fmassa added the question Further information is requested label Nov 19, 2018
@Igal20
Author

Igal20 commented Nov 19, 2018

@fmassa Thanks for the answer.
If I understand correctly, "One image should contain two things: a positive element and a negative element."
holds for evaluating the net, not for training it.

I'll elaborate a little bit more on what I aim to do.
For example, I want to train Mask R-CNN to detect and segment giraffes, so I provide 1000 sample images with giraffe masks. But I also have 10 images that merely look like the animal (for example, a blanket with a giraffe pattern), and I want to include them in the training process, even though the masks for those images will obviously be blank.
Thus I want to be able to add images without a segmentation mask.
Thanks
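
For reference, a hedged sketch of how such a negative image can be represented in a COCO-style annotation file: the image is listed under "images" but simply has no entries under "annotations". It is shown here as a Python dict that json.dump would serialize; all ids and file names are made up for illustration. Whether the default loader keeps such images is a separate question; see the discussion below.

coco_style = {
    "images": [
        {"id": 1, "file_name": "giraffe_001.jpg", "width": 640, "height": 480},
        {"id": 2, "file_name": "giraffe_blanket_001.jpg", "width": 640, "height": 480},  # negative image
    ],
    "annotations": [
        # only image 1 has an object; image 2 has no annotation entries at all
        {"id": 10, "image_id": 1, "category_id": 1,
         "bbox": [100, 50, 200, 300],
         "segmentation": [[100, 50, 300, 50, 300, 350, 100, 350]],
         "area": 60000, "iscrowd": 0},
    ],
    "categories": [{"id": 1, "name": "giraffe"}],
}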

@fmassa
Contributor

fmassa commented Nov 19, 2018

I see, thanks for the explanation!

It is possible to support your use case, but you'll need to adapt a few things in the code.
The first issue you'll hit is this one: #31
You can make the Matcher return a tensor of size N, but you'll probably face a few other issues down the road that will need to be addressed.

Let me know if you get stuck on a particular problem. Once the code is working, it might be interesting to see what it took to make it work, so that we could potentially send a PR adding support for this to the codebase.
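
As a rough illustration of the Matcher change hinted at above, a minimal sketch (an assumption on my part, not the final patch; the actual change ends up being discussed further down this thread):

import torch

def match_with_empty_targets(match_quality_matrix):
    # match_quality_matrix has shape (num_gt, num_predicted); if the image
    # has no ground-truth boxes, return -1 (background) for every prediction
    # instead of raising a ValueError
    if match_quality_matrix.shape[0] == 0:
        num_predicted = match_quality_matrix.shape[-1]
        return torch.full(
            (num_predicted,), -1,
            dtype=torch.int64, device=match_quality_matrix.device,
        )
    # ...the usual argmax-based matching over ground-truth boxes follows here
    return match_quality_matrix.argmax(dim=0)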

@Igal20
Author

Igal20 commented Nov 19, 2018

Thanks I'll check it.

@Iwontbecreative
Contributor

I have the same objective of being able to train on negative examples (images where there is nothing to detect).

My goal is both one of learning from negative data (the fact that there is nothing should at least help train the RPN) and calibration (most of my images have nothing, so training only on positive data tends to lead to overconfident predictions).

Would be great if this were supported, although I can see how it would make many things harder (e.g. sampling, dimensions, etc.)

@fmassa
Contributor

fmassa commented Nov 20, 2018

I'd be willing to support this use case, so if you have trouble getting the code to work for those images, just let me know and I'll help you so that we can have a PR for it. There might be a few tricky things to tune, though.

@LU4E

LU4E commented Dec 10, 2018

I have the same objective of being able to train on negative examples (images where there is nothing to detect).

I have the same question.

@fmassa
Contributor

fmassa commented Dec 10, 2018

@LU4E I still have the same answer as before :-)

@Iwontbecreative
Contributor

I have some code for handling empty images, but I won't get around to disentangling it from the rest of my code before Dec 19.

@fmassa
Contributor

fmassa commented Dec 10, 2018

@Iwontbecreative it would be a very nice addition!

Could you briefly summarize the things you did to make it work nicely?

@Iwontbecreative
Contributor

Iwontbecreative commented Dec 14, 2018

This is probably not in the right form and not directly tested on the new codebase, but these should be the main changes I made.

https://pastebin.com/6xXEWtvg

@fmassa
Contributor

fmassa commented Dec 14, 2018

@Iwontbecreative thanks a lot for the patch!

Do you know whether applying this patch gives better results than simply removing all the images that have no labels?

@Iwontbecreative
Contributor

On my own dataset it seemed to help, but we have ~100 times as much unlabelled data (and it comes from a slightly different distribution, and our evaluation is not mainly about bounding-box prediction).

Sadly, I haven't done any experiments on COCO. Could be a configuration flag, maybe?

@BobZhangHT

BobZhangHT commented Jan 17, 2019

@Iwontbecreative Hi! Thanks for your awesome code for negative-sample training. But after changing the code as you suggested in https://pastebin.com/6xXEWtvg, I can no longer train my model on multiple GPUs (it always gets stuck at the very beginning of training, as shown below). Have you ever run into this problem before? And @fmassa, could you please give any suggestions? (I'm sure I could use multiple GPUs before, so it's probably not a driver problem.)

2019-01-17 20:12:01,811 maskrcnn_benchmark.trainer INFO: Start training

@fmassa
Contributor

fmassa commented Jan 17, 2019

@BobZhangHT If you changed your PyTorch version in between, that might be the reason. Apart from that, I don't know.

@BobZhangHT

@fmassa Sincere thanks for your reply. Actually, I didn't change the PyTorch version. I will try to figure it out and let you know if it gets resolved.

@Iwontbecreative
Contributor

I haven't run into this issue; not sure what is going wrong, sorry...

@jgbos

jgbos commented Jan 17, 2019

Are the code changes available to see somewhere on GitHub? Unfortunately I cannot access pastebin.com.

@Iwontbecreative
Contributor

Sorry, my codebase was quite different from the commit at the time, so I cherry-picked the changes rather than opening a proper merge request.

diff --git a/maskrcnn_benchmark/modeling/matcher.py b/maskrcnn_benchmark/modeling/matcher.py
index 35ec5f1..074734c 100644
--- a/maskrcnn_benchmark/modeling/matcher.py
+++ b/maskrcnn_benchmark/modeling/matcher.py
@@ -53,9 +53,9 @@ class Matcher(object):
         if match_quality_matrix.numel() == 0:
             # empty targets or proposals not supported during training
             if match_quality_matrix.shape[0] == 0:
-                raise ValueError(
-                    "No ground-truth boxes available for one of the images "
-                    "during training")
+                length = match_quality_matrix.shape[-1]
+                device = match_quality_matrix.device
+                return torch.ones(length, dtype=torch.int64, device=device) * (-1)
             else:
                 raise ValueError(
                     "No proposal boxes available for one of the images "
diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
index 2c21f6c..bed4bbc 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
@@ -38,7 +38,11 @@ class FastRCNNLossComputation(object):
         # NB: need to clamp the indices because we can have a single
         # GT in the image, and matched_idxs can be -2, which goes
         # out of bounds
-        matched_targets = target[matched_idxs.clamp(min=0)]
+        if target.bbox.shape[0]:
+            matched_targets = target[matched_idxs.clamp(min=0)]
+        else:
+            target.add_field ("labels", matched_idxs.clamp(min=1, max=1))
+            matched_targets = target
         matched_targets.add_field("matched_idxs", matched_idxs)
         return matched_targets
 
@@ -63,9 +67,13 @@ class FastRCNNLossComputation(object):
             labels_per_image[ignore_inds] = -1  # -1 is ignored by sampler
 
             # compute regression targets
-            regression_targets_per_image = self.box_coder.encode(
-                matched_targets.bbox, proposals_per_image.bbox
-            )
+            if not matched_targets.bbox.shape[0]:
+                zeros = torch.zeros_like(labels_per_image, dtype=torch.float)
+                regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1)
+            else:
+                regression_targets_per_image = self.box_coder.encode(
+                    matched_targets.bbox, proposals_per_image.bbox
+                )
 
             labels.append(labels_per_image)
             regression_targets.append(regression_targets_per_image)
diff --git a/maskrcnn_benchmark/modeling/rpn/loss.py b/maskrcnn_benchmark/modeling/rpn/loss.py
index 0847231..a3dae25 100644
--- a/maskrcnn_benchmark/modeling/rpn/loss.py
+++ b/maskrcnn_benchmark/modeling/rpn/loss.py
@@ -43,7 +43,10 @@ class RPNLossComputation(object):
         # NB: need to clamp the indices because we can have a single
         # GT in the image, and matched_idxs can be -2, which goes
         # out of bounds
-        matched_targets = target[matched_idxs.clamp(min=0)]
+        if matched_idxs.clamp(min=0).sum() > 0:
+            matched_targets = target[matched_idxs.clamp(min=0)]
+        else:
+            matched_targets = target
         matched_targets.add_field("matched_idxs", matched_idxs)
         return matched_targets
 
@@ -55,6 +58,7 @@ class RPNLossComputation(object):
                 anchors_per_image, targets_per_image
             )
 
+
             matched_idxs = matched_targets.get_field("matched_idxs")
             labels_per_image = matched_idxs >= 0
             labels_per_image = labels_per_image.to(dtype=torch.float32)
@@ -66,9 +70,13 @@ class RPNLossComputation(object):
             labels_per_image[inds_to_discard] = -1
 
             # compute regression targets
-            regression_targets_per_image = self.box_coder.encode(
-                matched_targets.bbox, anchors_per_image.bbox
-            )
+            if not matched_targets.bbox.shape[0]:
+                zeros = torch.zeros_like(labels_per_image)
+                regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1)
+            else:
+                regression_targets_per_image = self.box_coder.encode(
+                    matched_targets.bbox, anchors_per_image.bbox
+                )
 
             labels.append(labels_per_image)
             regression_targets.append(regression_targets_per_image)
@@ -95,6 +103,8 @@ class RPNLossComputation(object):
 
         sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
 
+
+
         objectness_flattened = []
         box_regression_flattened = []
         # for each feature level, permute the outputs to make them be in the

@BobZhangHT

@Iwontbecreative Thank you! It seems to be a problem in my code.

@jgbos

jgbos commented Jan 18, 2019

@Iwontbecreative I take it you updated the BoxList class as well to allow for empty box arrays?

@jgbos

jgbos commented Jan 18, 2019

@Iwontbecreative I take it you updated the BoxList class as well to allow for empty box arrays?

@Iwontbecreative Apologies for not fully debugging before commenting. I just needed to make sure the box input is np.zeros((0, 4)). The code is running without error now.
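
For anyone hitting the same thing, a minimal sketch of the empty-target construction described above (make_empty_target is a hypothetical helper; BoxList is the stock class from maskrcnn_benchmark.structures.bounding_box):

import numpy as np
import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList

def make_empty_target(image_size):
    # a (0, 4) box array lets BoxList represent an image with no objects
    boxes = torch.as_tensor(np.zeros((0, 4)), dtype=torch.float32)
    target = BoxList(boxes, image_size, mode="xyxy")
    target.add_field("labels", torch.empty(0, dtype=torch.int64))
    return target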

@fmassa
Contributor

fmassa commented Jan 25, 2019

@IssamLaradji I do it in the initialization of the dataset:

if remove_images_without_annotations:
    self.ids = [
        img_id
        for img_id in self.ids
        if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
    ]
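
For context, a hedged sketch of where that filter sits, with the surrounding class approximated from maskrcnn_benchmark/data/datasets/coco.py (skipping the filter is what keeps negative images in the training set):

import torchvision

class COCODataset(torchvision.datasets.coco.CocoDetection):
    def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):
        super(COCODataset, self).__init__(root, ann_file)
        self.ids = sorted(self.ids)
        if remove_images_without_annotations:
            # drop every image id that has no annotations attached to it
            self.ids = [
                img_id
                for img_id in self.ids
                if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
            ]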

@IssamLaradji

Thanks a lot, @fmassa!

@jgbos

jgbos commented Jan 26, 2019

I have successfully used the method above to include negative images (and updated the mask head as well). I don't have a lot of data, and the negative images are important.

Thanks to @fmassa for a great repo.

@fmassa
Contributor

fmassa commented Jan 28, 2019

@jgbos great, thanks for the information!

I'll keep a note of it; it seems that merging a patch with this fix would make things work fine in many cases, so it would be a great addition!

@botcs
Contributor

botcs commented Mar 5, 2019

Hi @fmassa, @Iwontbecreative,

based on this comment and @jgbos's success, it seems that submitting a PR with this patch would be quite useful.

@AdanMora, could you help with making a unit test for this patch?

@fmassa
Contributor

fmassa commented Mar 5, 2019

@botcs I agree; if someone sends a PR with unit tests, I'll be more than happy to merge it!

@AdanMora

AdanMora commented Mar 5, 2019

@botcs @fmassa Excellent, I'll take a look and try to write some unit tests.

@dingguo1996

@botcs @fmassa Excellent, I'll take a look and try to write some unit tests.

Could you also add the mAP metric for the negative samples when testing?

@botcs
Contributor

botcs commented Mar 6, 2019

@qq237942920 Nyeh. It sounds like a simple task, but I have a hunch that you cannot compute mAP for negative samples: when evaluating AP for a single class, you iterate over precision and recall values at different IoUs. What you could do is compute F1 or precision/recall with the following image-level classification:
[Positive: has prediction] [TruePositive: has annotation and has prediction]...

(Or one extremely ugly trick could be to annotate the background as an object category and build a dataset that way, but I never said this.)
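
A minimal sketch of that image-level precision/recall idea, treating "has at least one prediction" vs. "has at least one annotation" as a binary classification per image (all names here are hypothetical):

def image_level_prec_rec(predictions_per_image, annotations_per_image):
    tp = fp = fn = tn = 0
    for preds, annos in zip(predictions_per_image, annotations_per_image):
        has_pred, has_anno = len(preds) > 0, len(annos) > 0
        if has_pred and has_anno:
            tp += 1  # detected something on an annotated image
        elif has_pred and not has_anno:
            fp += 1  # false alarm on a negative image
        elif not has_pred and has_anno:
            fn += 1  # missed an annotated image entirely
        else:
            tn += 1  # correctly silent on a negative image
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall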

@dingguo1996

dingguo1996 commented Mar 7, 2019

@qq237942920 Nyeh. It sounds like a simple task, but I have a hunch that you cannot compute mAP for negative samples: when evaluating AP for a single class, you iterate over precision and recall values at different IoUs. What you could do is compute F1 or precision/recall with the following image-level classification:
[Positive: has prediction] [TruePositive: has annotation and has prediction]...

(Or one extremely ugly trick could be to annotate the background as an object category and build a dataset that way, but I never said this.)

Thank you for your patient reply! Maybe I didn't express myself well in the last comment. By computing the mAP for the negative samples I mean: when my dataset has both positive and negative samples, will coco_eval compute the mAP including false positives on the negative samples (images with no ground truth)? The trick you suggested at the end may be a way to compute it.

@tpys

tpys commented Mar 9, 2019

@Iwontbecreative
the check matched_idxs.clamp(min=0).sum() > 0 in maskrcnn_benchmark/modeling/rpn/loss.py is not correct: when there is only one ground-truth box, matched_idxs.clamp(min=0).sum() can equal zero, because the matched target index is 0, so the image is wrongly treated as having no targets.

@cltdevelop

How should the category_id be assigned for negative-sample images? Should it be 0?

@cltdevelop

@Iwontbecreative @fmassa Could you write up the detailed process for training a model on negative samples? Thanks!

@fmassa
Contributor

fmassa commented Mar 12, 2019

@cltdevelop All the information I know about, I've already written here, but if @Iwontbecreative has time and wants to send a PR adding more detailed information, I'd be happy to merge it.

@Iwontbecreative
Contributor

Iwontbecreative commented Mar 13, 2019

I don't have a lot of time to commit to this right now, unfortunately. I have also never experimented on standard datasets, and the one I ran my experiments on is not public, so it would be tricky to give concrete results/numbers. I hope to have more time at some point to assess when this is useful on standard datasets and to come up with a proper PR.

@botcs
Contributor

botcs commented Mar 13, 2019

@fmassa @Iwontbecreative I have prepared a synthetic DebugDataset which I use for unit tests. It just puts a random number of white boxes on a black canvas. It could also be used, as is, for sanity-checking your implementation.
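
For reference, a hedged sketch of what such a synthetic dataset could look like (names and details are my assumptions, not botcs's actual code):

import random
import torch

class DebugDataset(torch.utils.data.Dataset):
    """Black images with a random number of white boxes; zero boxes = a negative image."""

    def __init__(self, num_images=16, size=128, max_boxes=3):
        self.num_images = num_images
        self.size = size
        self.max_boxes = max_boxes

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        img = torch.zeros(3, self.size, self.size)
        boxes = []
        for _ in range(random.randint(0, self.max_boxes)):
            x0 = random.randint(0, self.size - 16)
            y0 = random.randint(0, self.size - 16)
            x1 = x0 + random.randint(8, 15)
            y1 = y0 + random.randint(8, 15)
            img[:, y0:y1, x0:x1] = 1.0  # draw a white box
            boxes.append([x0, y0, x1, y1])
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        return img, boxes, idx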

@shimen

shimen commented Mar 17, 2019

@tpys
You are right!!

@Iwontbecreative
the check matched_idxs.clamp(min=0).sum() > 0 in maskrcnn_benchmark/modeling/rpn/loss.py is not correct: when there is only one ground-truth box, matched_idxs.clamp(min=0).sum() can equal zero, because the matched target index is 0, so the image is wrongly treated as having no targets.

It should be something like this (len(target) checks the number of ground-truth boxes directly instead of summing the match indices):

if len(target):
    matched_targets = target[matched_idxs.clamp(min=0)]
else:
    matched_targets = target

@jgbos

jgbos commented May 1, 2019

Has anyone run into trouble with the fix above on the latest master? I'm getting an error that some of the proposals provided to the mask head now have zero width or height, which causes the code to crash on this line:

scaled_mask = cropped_mask.resize((M, M))

where a divide-by-zero error happens on this line:

ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
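
One possible workaround (an assumption on my part, not an official fix from the repo): filter out degenerate proposals before they reach the mask head, so the crop/resize never sees a zero-width or zero-height box. The boxlist_ops module also has a remove_small_boxes helper that does something similar on the RPN path.

def remove_degenerate_boxes(boxlist, min_size=1):
    # boxlist.bbox is an (N, 4) tensor in xyxy format; keep only boxes whose
    # width and height are both at least min_size pixels
    xmin, ymin, xmax, ymax = boxlist.bbox.unbind(dim=1)
    keep = ((xmax - xmin) >= min_size) & ((ymax - ymin) >= min_size)
    return boxlist[keep]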

krumo referenced this issue May 10, 2019
@ad12

ad12 commented Oct 11, 2019

Thank you for the details!

In the most recent state of the code, is it sufficient to initialize the box list with zeros and set the label to 0?

For example, in maskrcnn-benchmark/maskrcnn_benchmark/data/datasets/coco.py, adding the following check:

    def __getitem__(self, idx):
        img, anno = super(COCODataset, self).__getitem__(idx)
        
        if not anno:
            # negative image: create a dummy zero-size box with label 0 (background)
            boxes = [[0, 0, 0, 0]]
            boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
            target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
            target.add_field("labels", torch.tensor([0]))
        else:
            # filter crowd annotations
            # TODO might be better to add an extra field
            anno = [obj for obj in anno if obj["iscrowd"] == 0]

            boxes = [obj["bbox"] for obj in anno]
            boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
            target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")

            classes = [obj["category_id"] for obj in anno]
            classes = [self.json_category_id_to_contiguous_id[c] for c in classes]
            classes = torch.tensor(classes)
            target.add_field("labels", classes)

        target = target.clip_to_image(remove_empty=False)

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        return img, target, idx

@Shankarram2709

Regarding @Igal20's question: for a single class (1 = object, 0 = background) in binary segmentation, I think using training samples with no ground-truth mask or segmentation contour will simply hurt your training accuracy, since semantic segmentation is performed pixelwise against the target. If a target contains no positive pixels, backpropagation will push the learned feature weights back toward unlearned weights (for example, if you have equal amounts of positive targets (ground truth with segmentation contours) and negative targets (ground truth with no segmentation contours), the features learned from the positive targets are undone whenever the negative targets are fed in during backprop). So negative samples may genuinely be needed in classification, but in segmentation the only way may be to avoid them.
