Skip to content

Commit

Permalink
[*] bcnet for maskrcnn: data, layers
Browse files Browse the repository at this point in the history
  • Loading branch information
trqminh committed Mar 2, 2022
1 parent 869450c commit 181a4da
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 2 deletions.
207 changes: 205 additions & 2 deletions detectron2/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,201 @@

logger = logging.getLogger(__name__)

__all__ = ["load_coco_json", "load_sem_seg", "convert_to_coco_json", "register_coco_instances"]
__all__ = ["load_coco_json", "load_sem_seg", "convert_to_coco_json", "register_coco_instances", "load_coco_json_eval"]

def load_coco_json_eval(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
    """
    Read a COCO-style instances json and turn it into a list of dataset dicts.

    This loader only copies the standard per-annotation fields ("iscrowd",
    "bbox", "keypoints", "category_id", "segmentation") plus any requested
    extra keys; it does not read "bg_object_segmentation" or "i_segmentation".

    Args:
        json_file (str): path to the json file in COCO instances annotation format.
            It is resolved through `PathManager.get_local_path` first.
        image_root (str or path-like): directory that is joined with each image's
            "file_name" to build the record's "file_name".
        dataset_name (str or None): name of the dataset (e.g., coco_2017_train).
            When it is not None, this function additionally:

            * stores "thing_classes" (category names ordered by category id) in the
              metadata registered under this name;
            * maps the json category ids to contiguous 0-based ids and stores the
              mapping as "thing_dataset_id_to_contiguous_id" in the same metadata.

            Leave it as None only if the raw json category ids are needed.
        extra_annotation_keys (list[str]): additional per-annotation keys that are
            copied as-is into each annotation dict when present.

    Returns:
        list[dict]: one dict per image with the keys "file_name", "height",
        "width", "image_id" and "annotations". If `dataset_name` is None the
        "category_id" values are the raw (possibly non-contiguous) json ids.

    Raises:
        ValueError: if an annotation carries an empty "bbox".
        KeyError: if an annotation's category id is not listed in the json's
            "categories" (only checked when an id mapping is in effect).
        AssertionError: on duplicated annotation ids (except for files whose path
            contains "minival"), on an annotation attached to the wrong image, or
            on a non-zero "ignore" field.

    Notes:
        Image files are never opened; the records carry no "image" field.
    """
    from pycocotools.coco import COCO

    timer = Timer()
    json_file = PathManager.get_local_path(json_file)
    # Discard whatever the COCO constructor prints to stdout.
    with contextlib.redirect_stdout(io.StringIO()):
        coco_api = COCO(json_file)
    if timer.seconds() > 1:
        logger.info(f"Loading {json_file} takes {timer.seconds():.2f} seconds.")

    id_map = None
    if dataset_name is not None:
        meta = MetadataCatalog.get(dataset_name)
        cat_ids = sorted(coco_api.getCatIds())
        cats = coco_api.loadCats(cat_ids)
        # Sort explicitly by id: a custom json may list its categories in any order.
        meta.thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]

        # COCO leaves gaps in its category ids, so ids are always remapped to a
        # contiguous 0-based range using the json's "categories" field.
        # The same remapping is applied to user datasets, but those get a warning
        # when their ids are not exactly 1..#categories.
        ids_span_one_to_n = min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)
        if not ids_span_one_to_n and "coco" not in dataset_name:
            logger.warning(
                "\nCategory ids in annotations are not in [1, #categories]! "
                "We'll apply a mapping for you.\n"
            )
        id_map = {v: i for i, v in enumerate(cat_ids)}
        meta.thing_dataset_id_to_contiguous_id = id_map

    # Sorted image ids keep the output order reproducible.
    img_ids = sorted(coco_api.imgs.keys())
    # Each image entry is a dict such as:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = coco_api.loadImgs(img_ids)
    # anns[i] is the list of annotation dicts of image img_ids[i], e.g.:
    # [{'segmentation': [[192.81, 247.09, ..., 219.03, 249.06]],
    #   'area': 1035.749,
    #   'iscrowd': 0,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]

    total_num_valid_anns = sum(len(per_image) for per_image in anns)
    total_num_anns = len(coco_api.anns)
    if total_num_valid_anns < total_num_anns:
        logger.warning(
            f"{json_file} contains {total_num_anns} annotations, but only "
            f"{total_num_valid_anns} of them match to images in the file."
        )

    if "minival" not in json_file:
        # The widely used valminusminival & minival files for COCO2014 do contain
        # duplicated annotation ids; the affected fraction is tiny and harmless,
        # so those files are exempt from this uniqueness check.
        ann_ids = [ann["id"] for per_image in anns for ann in per_image]
        assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
            json_file
        )

    imgs_anns = list(zip(imgs, anns))
    logger.info(f"Loaded {len(imgs_anns)} images in COCO format from {json_file}")

    ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])
    num_instances_without_valid_segmentation = 0
    dataset_dicts = []

    for img_dict, anno_dict_list in imgs_anns:
        record = {
            "file_name": os.path.join(image_root, img_dict["file_name"]),
            "height": img_dict["height"],
            "width": img_dict["width"],
            "image_id": img_dict["id"],
        }
        image_id = record["image_id"]

        objs = []
        for anno in anno_dict_list:
            # An annotation must belong to the image it is listed under. This only
            # trips on a buggy annotation file or buggy parsing; the original COCO
            # valminusminival2014 & minival2014 files are known to be able to
            # trigger it with certain ways of using the COCO API.
            assert anno["image_id"] == image_id

            assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'

            obj = {key: anno[key] for key in ann_keys if key in anno}
            if "bbox" in obj and len(obj["bbox"]) == 0:
                raise ValueError(
                    f"One annotation of image {image_id} contains empty 'bbox' value! "
                    "This json does not have valid COCO format."
                )

            segm = anno.get("segmentation", None)
            if segm:  # polygons (list[list[float]]) or RLE (dict)
                if not isinstance(segm, dict):
                    # Keep only polygons with an even coordinate count and >= 3 points.
                    segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                    if not segm:
                        num_instances_without_valid_segmentation += 1
                        continue  # drop the whole instance
                elif isinstance(segm["counts"], list):
                    # Uncompressed RLE: convert to compressed RLE.
                    segm = mask_util.frPyObjects(segm, *segm["size"])
                obj["segmentation"] = segm

            keypts = anno.get("keypoints", None)
            if keypts:  # flat list, processed in groups of three
                for idx, value in enumerate(keypts):
                    if idx % 3 == 2:
                        continue  # every third entry is left untouched
                    # COCO segmentation coordinates are floats in [0, H or W], while
                    # keypoint coordinates are integers in [0, H-1 or W-1]. They are
                    # therefore treated as pixel indices and shifted by 0.5 to get
                    # floating point coordinates. The list is updated in place.
                    keypts[idx] = value + 0.5
                obj["keypoints"] = keypts

            obj["bbox_mode"] = BoxMode.XYWH_ABS
            if id_map:
                annotation_category_id = obj["category_id"]
                try:
                    obj["category_id"] = id_map[annotation_category_id]
                except KeyError as e:
                    raise KeyError(
                        f"Encountered category_id={annotation_category_id} "
                        "but this id does not exist in 'categories' of the json file."
                    ) from e
            objs.append(obj)

        record["annotations"] = objs
        dataset_dicts.append(record)

    if num_instances_without_valid_segmentation > 0:
        logger.warning(
            f"Filtered out {num_instances_without_valid_segmentation} instances "
            "without valid segmentation. "
            "There might be issues in your dataset generation process. Please "
            "check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
        )
    return dataset_dicts

def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
"""
Expand Down Expand Up @@ -178,6 +371,8 @@ def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_ke
)

segm = anno.get("segmentation", None)
bo_segm = anno.get("bg_object_segmentation", None)
i_segm = anno.get("i_segmentation", None)
if segm: # either list[list[float]] or dict(RLE)
if isinstance(segm, dict):
if isinstance(segm["counts"], list):
Expand All @@ -186,10 +381,14 @@ def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_ke
else:
# filter out invalid polygons (< 3 points)
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
bo_segm = [poly for poly in bo_segm if len(poly) % 2 == 0 and len(poly) >= 6]
i_segm = [poly for poly in i_segm if len(poly) % 2 == 0 and len(poly) >= 6]
if len(segm) == 0:
num_instances_without_valid_segmentation += 1
continue # ignore this instance
obj["segmentation"] = segm
obj["bg_object_segmentation"] = bo_segm
obj["i_segmentation"] = i_segm

keypts = anno.get("keypoints", None)
if keypts: # list[int]
Expand Down Expand Up @@ -497,7 +696,11 @@ def register_coco_instances(name, metadata, json_file, image_root):
assert isinstance(json_file, (str, os.PathLike)), json_file
assert isinstance(image_root, (str, os.PathLike)), image_root
# 1. register a function which returns dicts
DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
if 'train' in name:
DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
else:
DatasetCatalog.register(name, lambda: load_coco_json_eval(json_file, image_root, name))


# 2. Optionally, add metadata about this dataset,
# since they might be useful in evaluation, visualization or logging
Expand Down
38 changes: 38 additions & 0 deletions detectron2/data/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,38 @@ def transform_instance_annotations(
" COCO-style RLE as a dict.".format(type(segm))
)

if "i_segmentation" in annotation:
# each instance contains 1 or more polygons
i_segm = annotation["i_segmentation"]
if isinstance(i_segm, list):
# polygons
polygons = [np.asarray(p).reshape(-1, 2) for p in i_segm]
annotation["i_segmentation"] = [
p.reshape(-1) for p in transforms.apply_polygons(polygons)
]
elif isinstance(i_segm, dict):
# RLE
mask = mask_util.decode(i_segm)
mask = transforms.apply_segmentation(mask)
assert tuple(mask.shape[:2]) == image_size
annotation["i_segmentation"] = mask
else:
raise ValueError(
"Cannot transform segmentation of type '{}'!"
"Supported types are: polygons as list[list[float] or ndarray],"
" COCO-style RLE as a dict.".format(type(i_segm))
)

if "bg_object_segmentation" in annotation:
# each instance contains 1 or more polygons
bo_segm = annotation["bg_object_segmentation"]
if isinstance(bo_segm, list):
# polygons
bo_polygons = [np.asarray(p).reshape(-1, 2) for p in bo_segm]
annotation["bg_object_segmentation"] = [
p.reshape(-1) for p in transforms.apply_polygons(bo_polygons)
]

if "keypoints" in annotation:
keypoints = transform_keypoint_annotations(
annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
Expand Down Expand Up @@ -398,9 +430,13 @@ def annotations_to_instances(annos, image_size, mask_format="polygon"):

if len(annos) and "segmentation" in annos[0]:
segms = [obj["segmentation"] for obj in annos]
bo_segms = [obj["bg_object_segmentation"] for obj in annos]
i_segms = [obj["i_segmentation"] for obj in annos]
if mask_format == "polygon":
try:
masks = PolygonMasks(segms)
bo_masks = PolygonMasks(bo_segms)
i_masks = PolygonMasks(i_segms)
except ValueError as e:
raise ValueError(
"Failed to use mask_format=='polygon' from the given annotations!"
Expand Down Expand Up @@ -433,6 +469,8 @@ def annotations_to_instances(annos, image_size, mask_format="polygon"):
torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
)
target.gt_masks = masks
target.gt_bo_masks = bo_masks
target.gt_i_masks = i_masks

if len(annos) and "keypoints" in annos[0]:
kpts = [obj.get("keypoints", []) for obj in annos]
Expand Down
1 change: 1 addition & 0 deletions detectron2/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
from .blocks import CNNBlockBase, DepthwiseSeparableConv2d
from .aspp import ASPP
from .losses import ciou_loss, diou_loss
from .boundary import get_instances_contour_interior

__all__ = [k for k in globals().keys() if not k.startswith("_")]
60 changes: 60 additions & 0 deletions detectron2/layers/boundary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
#from PIL import Image #, ImageOps, ImageDraw
from skimage import filters, img_as_ubyte
from skimage.morphology import remove_small_objects, dilation, erosion, binary_dilation, binary_erosion, square
#from scipy.ndimage.interpolation import map_coordinates
#from scipy.ndimage.morphology import binary_fill_holes
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.measurements import center_of_mass
from PIL import Image, ImageDraw

def get_contour_interior(mask, bold=False):
    """Split a binary mask into a contour band and a shrunken interior.

    Args:
        mask: binary mask array, passed directly to ``binary_dilation`` and
            ``binary_erosion`` with their default footprint.
        bold (bool): when True the mask is dilated a second time, which widens
            the contour band outwards.

    Returns:
        tuple: ``(contour, interior)``, both ``np.uint8`` arrays holding 0/1.
        ``contour`` marks the pixels where the dilated mask and the eroded mask
        disagree; ``interior`` is the eroded mask eroded one more time.
    """
    # One dilation outwards plus one erosion inwards: per the original note this
    # gives a 2-pixel contour (1 out + 1 in) and a 2-pixel shrunken interior.
    outer = binary_dilation(mask)
    if bold:
        outer = binary_dilation(outer)
    inner = binary_erosion(mask)
    # `outer != inner` is already a boolean array, so no extra thresholding needed.
    contour = (outer != inner).astype(np.uint8)
    interior = (erosion(inner) > 0).astype(np.uint8)
    return contour, interior

def get_center(mask):
    """Draw a small filled disc at the centre of mass of ``mask``.

    Args:
        mask: array whose centre of mass unpacks into exactly two coordinates
            (row, column).

    Returns:
        An array obtained from a uint8 image of the same shape as ``mask``:
        all zeros except for a white ellipse of radius 2 around the centre of
        mass. If either centre coordinate is NaN, nothing is drawn and the
        result is all zeros.
    """
    radius = 2
    center_y, center_x = center_of_mass(mask)
    canvas = Image.fromarray(np.zeros_like(mask).astype(np.uint8))
    has_valid_center = not (np.isnan(center_x) or np.isnan(center_y))
    if has_valid_center:
        bounding_box = [
            center_x - radius,
            center_y - radius,
            center_x + radius,
            center_y + radius,
        ]
        ImageDraw.Draw(canvas).ellipse(bounding_box, fill='White')
    return np.asarray(canvas)

def get_instances_contour_interior(instances_mask):
    """Compute contour map, interior map and a contour-emphasising weight map.

    Args:
        instances_mask: object exposing the mask through its ``.data`` attribute
            (presumably a tensor holding one binary mask -- TODO confirm against
            the caller). The mask is processed as a whole; it is not decomposed
            into separate instances.

    Returns:
        tuple: ``(result_c, result_i, weight)``.

        * ``result_c`` (uint8): contour pixels from ``get_contour_interior``.
        * ``result_i`` (uint8): interior pixels from ``get_contour_interior``.
        * ``weight`` (float32): 1 everywhere, raised near the contour by a
          Gaussian-blurred copy of the contour divided by 50.
    """
    # Hard-coded replacement for a former config option. With False, the contour
    # is merged with np.maximum and is not thresholded at the end.
    adjacent_boundary_only = False
    instances_mask = instances_mask.data
    result_c = np.zeros_like(instances_mask, dtype=np.uint8)
    result_i = np.zeros_like(instances_mask, dtype=np.uint8)
    weight = np.ones_like(instances_mask, dtype=np.float32)

    contour, interior = get_contour_interior(instances_mask, bold=adjacent_boundary_only)
    if adjacent_boundary_only:
        result_c += contour
    else:
        result_c = np.maximum(result_c, contour)
    result_i = np.maximum(result_i, interior)

    # Blur a floating point copy of the contour. gaussian_filter returns an array
    # of its input's dtype, so filtering the integer 0/1 contour would truncate
    # the blurred values back to integers and leave the weight map flat.
    contour_f = (contour > 0).astype(np.float32)
    # The blurred contour lies in [0, 1], so this factor lies in [1, 1.02].
    weight *= 1 + gaussian_filter(contour_f, sigma=1) / 50

    if adjacent_boundary_only:
        result_c = (result_c > 1).astype(np.uint8)
    return result_c, result_i, weight

0 comments on commit 181a4da

Please sign in to comment.