Skip to content

Commit

Permalink
feat(opendataset): add "SemanticMask" for "BDD100K_MOTS2020" dataset
Browse files Browse the repository at this point in the history
PR Closed: #1102
  • Loading branch information
marshallmallows committed Nov 22, 2021
1 parent 5d5cf30 commit 00427b7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
13 changes: 13 additions & 0 deletions tensorbay/opendataset/BDD100K_MOT2020/catalog_mots.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@
}
]
},
"SEMANTIC_MASK": {
"categories": [
{ "name": "background", "categoryId": 0 },
{ "name": "pedestrian", "categoryId": 1 },
{ "name": "rider", "categoryId": 2 },
{ "name": "car", "categoryId": 3 },
{ "name": "truck", "categoryId": 4 },
{ "name": "bus", "categoryId": 5 },
{ "name": "train", "categoryId": 6 },
{ "name": "motorcycle", "categoryId": 7 },
{ "name": "bicycle", "categoryId": 8 }
]
},
"INSTANCE_MASK": {
"categories": [{ "name": "background", "categoryId": 0 }],
"attributes": [
Expand Down
37 changes: 25 additions & 12 deletions tensorbay/opendataset/BDD100K_MOT2020/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np

from tensorbay.dataset import Data, Dataset
from tensorbay.label import InstanceMask, LabeledBox2D, LabeledMultiPolygon
from tensorbay.label import InstanceMask, LabeledBox2D, LabeledMultiPolygon, SemanticMask
from tensorbay.opendataset._utility import glob

try:
Expand Down Expand Up @@ -187,7 +187,10 @@ def _generate_data(
if tracking_type == "mots":
original_mask_subdir = os.path.join(original_mask_dir, subdir_name)
mask_subdir = os.path.join(mask_dir, subdir_name)
os.makedirs(mask_subdir, exist_ok=True)
semantic_subdir = os.path.join(mask_subdir, "semantic")
instance_subdir = os.path.join(mask_subdir, "instance")
os.makedirs(semantic_subdir, exist_ok=True)
os.makedirs(instance_subdir, exist_ok=True)
with open(os.path.join(segment_labels_dir, f"{subdir_name}.json"), "r", encoding="utf-8") as fp:
label_contents = json.load(fp)
for label_content in label_contents:
Expand All @@ -200,9 +203,10 @@ def _generate_data(
) if tracking_type == "mot" else _get_mots_data(
image_path,
original_mask_subdir,
mask_subdir,
semantic_subdir,
instance_subdir,
os.path.splitext(label_content_name)[0],
label_content,
label_content=label_content,
)


Expand Down Expand Up @@ -231,8 +235,10 @@ def _get_mot_data(image_path: str, label_content: Dict[str, Any]) -> Data:
def _get_mots_data(
image_path: str,
original_mask_subdir: str,
mask_subdir: str,
semantic_subdir: str,
instance_subdir: str,
stem: str,
*,
label_content: Dict[str, Any],
) -> Data:
data = Data(image_path)
Expand All @@ -248,25 +254,28 @@ def _get_mots_data(
)
labeled_multipolygons.append(labeled_multipolygon)

mask_path = os.path.join(mask_subdir, f"{stem}.png")
semantic_path = os.path.join(semantic_subdir, f"{stem}.png")
instance_path = os.path.join(instance_subdir, f"{stem}.png")
mask_info = _save_and_get_mask_info(
os.path.join(original_mask_subdir, f"{stem}.png"),
mask_path,
os.path.join(mask_subdir, f"{stem}.json"),
semantic_path,
instance_path,
os.path.join(instance_subdir, f"{stem}.json"),
)
ins_mask = InstanceMask(mask_path)
ins_mask = InstanceMask(instance_path)
ins_mask.all_attributes = mask_info["all_attributes"]

label = data.label
label.multi_polygon = labeled_multipolygons
label.semantic_mask = SemanticMask(semantic_path)
label.instance_mask = ins_mask
return data


def _save_and_get_mask_info(
original_mask_path: str, mask_path: str, mask_info_path: str
original_mask_path: str, semantic_path: str, instance_path: str, mask_info_path: str
) -> Dict[str, Any]:
if not os.path.exists(mask_path):
if not os.path.exists(instance_path):
mask = np.array(Image.open(original_mask_path), dtype=np.uint16)
all_attributes = {}
for _, attributes, instance_id_high, instance_id_low in np.unique(
Expand All @@ -283,8 +292,12 @@ def _save_and_get_mask_info(
mask_info = {"all_attributes": all_attributes}
with open(mask_info_path, "w", encoding="utf-8") as fp:
json.dump(mask_info, fp)
Image.fromarray(mask[:, :, -1] + (mask[:, :, -2] << 8)).save(mask_path)
Image.fromarray(mask[:, :, -1] + (mask[:, :, -2] << 8)).save(instance_path)
if not os.path.exists(semantic_path):
Image.fromarray(mask[:, :, 0]).save(semantic_path)
else:
if not os.path.exists(semantic_path):
Image.fromarray(np.array(Image.open(original_mask_path))[:, :, 0]).save(semantic_path)
with open(mask_info_path, "r", encoding="utf-8") as fp:
mask_info = json.load(
fp,
Expand Down

0 comments on commit 00427b7

Please sign in to comment.