Support instance mask annotation with mask.png #256
Comments
There is an open PR that adds support for binary masks in #150. You can try it out to see if it works for your use case. I still haven't had the time to pull it down and try it out for myself, though, which is why I haven't merged the PR yet.
OK, I will try it today, and if it works I will report it here. In my opinion, a binary mask is a better format than polygons, and I think it should be the default format for instance masks.
I have tested the code and have written a corresponding dataloader for image-mask input data. This dataloader is modified from COCODataset. Would you like to help check it, especially the correspondence between the image information (like size) and the image?

```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import os

import numpy as np
import torch
import torchvision
from PIL import Image

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask


class COCODatasetBinaryMask(torchvision.datasets.coco.CocoDetection):
    def __init__(self, ann_file, root, transforms=None):
        super(COCODatasetBinaryMask, self).__init__(root, ann_file)
        # sort indices for reproducible results
        self.ids = sorted(self.ids)
        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self.transforms = transforms
        self.root = root
        self.image_root = self.root + "images/"
        self.mask_root = self.root + "masks/"
        image_names = [image_name.split(".")[0] for image_name in os.listdir(self.image_root) if ".jpg" in image_name]
        mask_names = [mask_name.replace("_mask", "").split(".")[0] for mask_name in os.listdir(self.mask_root) if ".png" in mask_name]
        # keep only images that have a matching mask
        self.names = list(set(image_names) & set(mask_names))

    def __getitem__(self, idx):
        name = self.names[idx]
        image_path = self.image_root + name + ".jpg"
        mask_path = self.mask_root + name + "_mask.png"
        img = Image.open(image_path)
        mask = np.array(Image.open(mask_path))
        boxes, masks = self._get_insts_bbox_mask_from_mask(mask, third_object_color="red")
        # boxes: a list of [x, y, w, h] lists
        boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = [1] * len(boxes)  # only one class in my dataset
        classes = torch.tensor(classes)
        target.add_field("labels", classes)
        # masks: a list of numpy arrays
        masks = SegmentationMask(masks, img.size)
        target.add_field("masks", masks)
        target = target.clip_to_image(remove_empty=True)
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target, idx

    def get_img_info(self, index):
        # NOTE: __getitem__ indexes self.names, while this uses self.ids;
        # the two orderings must agree for img_data to describe the right image
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        print(img_data)
        return img_data

    def _get_insts_bbox_mask_from_mask(self, mask, third_object_color="red"):
        # every instance is painted with a unique color in the mask PNG
        colors = np.unique(mask.reshape(-1, mask.shape[2]), axis=0)
        colors = [list(color) for color in colors]
        if third_object_color == "red":
            abandon_colors = [[0, 0, 0], [0, 0, 255]]
        elif third_object_color == "pink":
            abandon_colors = [[0, 0, 0], [237, 199, 244]]  # pink as the 3rd object
        else:
            raise ValueError("unsupported third_object_color")
        inst_colors = [color for color in colors if color not in abandon_colors]
        boxes = []
        masks = []
        for i in range(len(inst_colors)):
            inst_mask = np.all(np.equal(mask, inst_colors[i]), axis=2)
            inst_mask = np.where(inst_mask == True, 1, 0)
            inst_mask = inst_mask.astype(np.uint8)
            # kernel_open = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
            # inst_mask = cv2.morphologyEx(inst_mask, cv2.MORPH_CLOSE, kernel_open)
            # kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
            # inst_mask = cv2.morphologyEx(inst_mask, cv2.MORPH_CLOSE, kernel_close)
            box = self._bbox(inst_mask)
            box_area = self._area(box)
            if box_area >= 100:  # drop tiny instances
                y_min, y_max, x_min, x_max = box
                boxes.append([x_min, y_min, x_max - x_min, y_max - y_min])
                masks.append(inst_mask)
        return boxes, masks

    def _bbox(self, img):
        a = np.where(img != 0)
        bbox = np.min(a[0]), np.max(a[0]), np.min(a[1]), np.max(a[1])
        return bbox  # y_min, y_max, x_min, x_max

    def _area(self, box):
        return (box[1] - box[0]) * (box[3] - box[2])
```
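For reference, a minimal smoke test of this dataset might look like the sketch below (the paths are placeholders; note that `root` must end with a slash, since the class concatenates `"images/"` directly onto it):

```python
# Hypothetical smoke test for the dataset above; paths are placeholders.
dataset = COCODatasetBinaryMask("path/to/data_binary_mask.json", "path/to/dataset/")
img, target, idx = dataset[0]
print(img.size)                               # PIL size is (width, height)
print(target.bbox.shape)                      # (num_instances, 4) in xyxy mode
print(len(target.get_field("masks").masks))   # one binary mask per instance
```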
According to https://github.com/facebookresearch/maskrcnn-benchmark/pull/150/files#diff-928af5178eceaef7d662fe22c85f439aR209, here is one thing I'd do to verify that the code works as expected (without any transforms in the dataset): flip the target twice and check that you get the original masks back.
I will try to verify that, thanks! What's more, because …
Oh, I missed that.
I have verified the transpose operation; it works.
Cool! Now if the rest of the training works out of the box with your dataset, that is a good signal that we can look again into merging that PR.
Yes. I will generate an offline dataset (instance masks) and train on my dataset. After that, if both polygon annotations and binary mask annotations work, maybe we should consider merging that PR. I will report the training results within 24 hours.
You can increase the number of worker threads in the dataloader so that you don't need to generate it offline - it will probably be simpler.
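With a plain PyTorch `DataLoader`, that is just the `num_workers` argument; a minimal sketch (in maskrcnn-benchmark itself the corresponding config value is, to my knowledge, `DATALOADER.NUM_WORKERS`):

```python
from torch.utils.data import DataLoader

# More workers extract instance masks in parallel background processes,
# so there is no need to pre-generate them offline.
loader = DataLoader(
    dataset,                          # e.g. the COCODatasetBinaryMask above
    batch_size=1,
    shuffle=True,
    num_workers=4,
    collate_fn=lambda batch: batch,   # BoxList targets need a custom collate
)
```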
OK. I have implemented both the online and offline methods and am currently using the online one. I tried to overfit a large model on only 2 images, but the result is not that good. The predicted mask seems to be shifted a few pixels to the right compared with the ground-truth mask, as you can see in these images. I have no idea what's wrong here.
This might indicate that there are still a few problems with the current implementation in the …
Thanks, I will try that!
I tried transposing the BoxList twice, and it turns out it gives back the original result.

```python
import numpy as np

from maskrcnn_benchmark.data.datasets.coco_binary_mask import COCODatasetBinaryMaskOnLine
ann_file = "path/to/data_binary_mask.json"
root = "path/to/dataset"
coco_binary_mask = COCODatasetBinaryMaskOnLine(ann_file, root, transforms=None)
img, target, _ = coco_binary_mask[1]  # index=1 for example
masks = target.get_field('masks').masks
f_f_target = target.transpose(0).transpose(0)  # flip left-right, twice
f_f_masks = f_f_target.get_field('masks').masks

# masks flipped twice
mask_1 = f_f_masks[0].mask.numpy()
mask_2 = f_f_masks[1].mask.numpy()
# original masks
mask_3 = masks[0].mask.numpy()
mask_4 = masks[1].mask.numpy()

print(np.all(np.equal(mask_1, mask_3)))
print(np.all(np.equal(mask_1, mask_4)))
print(np.all(np.equal(mask_2, mask_3)))
print(np.all(np.equal(mask_2, mask_4)))
# result is two `True` and two `False`, which means the flipped masks
# are equal to the original ones
```

What's more, I tried turning off data augmentation, but the trained model still predicts a shifted mask.
@fmassa The only difference that I can see is …
So, one thing to also take into account is that the transforms rescale the image during training so that it has a particular size. This downsampling …
The current implementation of Polygons here follows the implementation in Detectron, which is a legacy behavior that adds a …
@fmassa The shift problem in my dataset has been solved by just using bilinear interpolation instead of nearest. I don't know exactly why that happens...
I think nearest interpolation might bring those artifacts, but great to know that this was the solution for your case, it's very helpful!
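For anyone hitting the same artifact, here is a minimal standalone sketch of the idea (plain PIL/NumPy, not the PR's actual resize code): resample the mask with bilinear interpolation, then re-binarize with a threshold.

```python
import numpy as np
from PIL import Image

def resize_binary_mask(mask, size, threshold=0.5):
    """mask: HxW uint8 array with values in {0, 1}; size: (width, height)."""
    # Bilinear resampling produces soft edge values; thresholding turns them
    # back into a crisp binary mask. Nearest-neighbor resampling can instead
    # shift the whole boundary by up to a pixel, which looks like the
    # reported offset.
    resized = Image.fromarray(mask * 255).resize(size, Image.BILINEAR)
    return (np.asarray(resized, dtype=np.float32) / 255.0 >= threshold).astype(np.uint8)
```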
I trained my model on a small dataset that contains 2 images and it succeeds. But when I try to train on a larger dataset, a GPU out-of-memory (OOM) error occurs. I'm using one GPU and one image per batch. By the way, I am using the offline dataset.
The reason is that you probably have a lot of ground-truth (GT) instances per image.
Thanks for that. One thing I want to make sure of: the OOM is not caused by loading all the data before training in the DataLoader, right? Is it true that we only load a batch of images and targets during training? I think there will not be more than 5 instances per image, and both GPU memory and CPU memory usage are high.
The OOM is happening on the GPU, so it's probably not related to data loading, I believe.
Update: I set …, and I encountered an OOM when loading data, as follows.
I have carried out a few experiments. When I use a dataset of only a few images (like 2 or 5), training works well. However, when I try to train on a large dataset, there is always a GPU OOM problem. So I think that maybe it's not because of too many instances in an image, but because I saved some data about the whole dataset in GPU memory, which gives a GPU OOM when using a large dataset. Could you please give me some hints on where to debug? Thanks in advance!
Depending on how you store the data in the dataset (for example, as a numpy array), each worker will copy the whole object to each new thread, making it require a lot of CPU memory. If those are torch tensors, you should be fine. I'd recommend not loading the full dataset in memory, but instead loading it at every `__getitem__` call.
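A sketch of that recommendation, assuming the image/mask file layout used earlier in this thread (the class and names are hypothetical):

```python
import os
from PIL import Image
from torch.utils.data import Dataset

class LazyMaskDataset(Dataset):
    """Keeps only file names in memory; all I/O happens per sample."""
    def __init__(self, image_root, mask_root):
        self.image_root, self.mask_root = image_root, mask_root
        self.names = sorted(n.split(".")[0] for n in os.listdir(image_root) if n.endswith(".jpg"))

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        # Reading here, instead of in __init__, means workers never inherit
        # a full in-memory copy of the dataset.
        img = Image.open(os.path.join(self.image_root, self.names[idx] + ".jpg"))
        mask = Image.open(os.path.join(self.mask_root, self.names[idx] + "_mask.png"))
        return img, mask
```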
Thanks for your patience! It turns out that some images have too many instances, which is the result of bad labeling. After removing those images and labels, training works well.
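A hedged sketch of that kind of sanity filter, reusing the color-per-instance convention from the dataset code above (the threshold is an arbitrary assumption):

```python
import numpy as np
from PIL import Image

def has_sane_instance_count(mask_path, max_instances=20):
    """Reject masks whose number of distinct instance colors suggests bad labeling."""
    mask = np.array(Image.open(mask_path))
    colors = np.unique(mask.reshape(-1, mask.shape[2]), axis=0)
    return len(colors) - 1 <= max_instances  # minus one for the background color
```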
Awesome, thanks! So, to summarize, the only thing that you had to change in order for your training to work as expected (on top of the PR adding better mask support) is to change the interpolation mode to bilinear instead of nearest, is that right?
Sorry for the late reply. Yes, after the interpolation change, it works for me. However, when performing resize and crop operations, you can only use integer coordinates with binary masks, rather than the float coordinates used with polygons. And if the mask resolution is small (like 28), the ground-truth mask derived from a binary mask is not as good as one derived from polygons.
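A small illustration of that precision point (the values are made up): polygon coordinates can be rescaled with float ratios, while a binary mask has to be resampled onto an integer pixel grid, which loses sub-pixel detail at low resolutions like 28x28.

```python
import numpy as np
from PIL import Image

# Polygons: float coordinates survive rescaling without quantization.
poly = np.array([4.3, 4.7, 20.1, 4.7, 20.1, 12.9])  # x, y pairs
scaled_poly = poly * (28.0 / 100.0)                  # still exact floats

# Binary masks: resampling snaps everything to the 28x28 integer grid.
mask = np.zeros((100, 100), dtype=np.uint8)
mask[5:13, 4:20] = 1
small = Image.fromarray(mask * 255).resize((28, 28), Image.BILINEAR)
small = (np.asarray(small) >= 128).astype(np.uint8)
```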
@txytju sounds good, thanks for the information!
Could you merge your code? I think your work is awesome.
@JoyHuYY1412 I'm willing to merge the PR that adds support for it; I'd just ask for it to have unit tests so that we know we are computing the same things for polygons and masks.
I could help with the unit tests; I think this PR would be useful for many people.
@botcs yes, please!
So, could you please make a list of tests that should be done? Another thing: if this module is merged, then can we use the evaluation tools …
Here are a few tests I think would be useful to have: … Let me know if you have further questions!
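As one concrete sketch of such a consistency test (using pycocotools directly rather than the repository's own classes, and comparing by IoU instead of exact equality because of the legacy rasterization offset discussed above):

```python
import numpy as np
from pycocotools import mask as mask_utils

def rasterize(poly, h, w):
    """Rasterize a single polygon into an HxW binary mask."""
    rles = mask_utils.frPyObjects([poly], h, w)
    return mask_utils.decode(mask_utils.merge(rles))

def test_hflip_polygon_vs_mask(h=64, w=64):
    poly = [8.0, 8.0, 40.0, 8.0, 40.0, 24.0, 8.0, 24.0]
    flipped_poly = list(poly)
    flipped_poly[0::2] = [w - x for x in poly[0::2]]  # mirror x coordinates
    a = rasterize(flipped_poly, h, w)
    b = rasterize(poly, h, w)[:, ::-1]                # flip the rasterized mask
    iou = np.logical_and(a, b).sum() / np.logical_or(a, b).sum()
    assert iou > 0.9  # loose bound: boundaries may differ by a pixel
```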
Okay, I still have a few questions, just for clarification: … [Edit]: I will continue this thread at #150 and will get back to this when the unit tests are done.
@fmassa @botcs I think it is hard to make the behavior of binary masks and polygons exactly the same, apart from the old Detectron inconsistency. PR #150 alone may not be enough for binary masks to work as well as polygons, since this codebase was optimized for polygon-based input. A good practice would be to try PR #150 and make the necessary modifications so that binary masks reach the same COCO performance as polygons (e.g., by adding a global config flag to switch between the two modes). You may refer to mmdetection for the necessary changes, since it inherently utilizes binary masks.
Thanks for your comments, @wangg12!
🚀 Feature
An instance mask image can be used as the ground-truth label.
For example, in the PNG file, every instance is labeled using a unique color.
Motivation
Currently, instance annotations are COCO-style, in which an instance mask is annotated by polygons. However, if an instance mask has holes, the polygon annotation fails.
A binary instance mask PNG, by contrast, can handle holes in instance masks.