Skip to content

Commit

Permalink
add mutilabel coco config
Browse files Browse the repository at this point in the history
  • Loading branch information
Noietch committed Nov 11, 2024
1 parent 1d59947 commit 35af0b7
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 6 deletions.
3 changes: 2 additions & 1 deletion exps/muti_label/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .dataset import LVISDataset
from .dataset import LVISDataset, COCODatasets

from .model import CLIPModel
from .metrics import MutiLabelMetric
26 changes: 26 additions & 0 deletions exps/muti_label/configs/clip_config_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
_base_ = ['base.py']

val_dataset = dict(
type='COCODatasets',
data_root='data/coco2017',
data_prefix=dict(
img_path='val2017',
),
ann_file='annotations/instances_val2017.json',
pipeline=[
dict(type='LoadImage'),
dict(type='CLIPTransforms'),
dict(type='PackData'),
],
)

val_dataloader = dict(
batch_size=32,
dataset=val_dataset,
sampler=dict(type='DefaultSampler', shuffle=False),
collate_fn=dict(type='default_collate')
)

model = dict(
type='CLIPModel',
)
50 changes: 50 additions & 0 deletions exps/muti_label/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from PIL import Image
from lvis.lvis import LVIS
from pycocotools.coco import COCO

from mmengine.dataset import BaseDataset
from mmengine.registry import DATASETS, TRANSFORMS
Expand Down Expand Up @@ -58,6 +59,55 @@ def load_data_list(self) -> list[dict]:
del self.lvis

return data_list

@DATASETS.register_module()
class COCODatasets(BaseDataset):

def parse_data_info(self, raw_data_info):
raw_ann_info = raw_data_info['raw_ann_info']
raw_img_info = raw_data_info['raw_img_info']

# print(raw_img_info)
# print([cur_cates[ann['category_id']-1] for ann in raw_ann_info])
# to one-hot
category_ids = torch.unique(torch.tensor([self.cat2label[ann['category_id']] for ann in raw_ann_info]))
cate_one_hot = torch.eye(len(cur_cates))[category_ids].sum(dim=0)

return {
"img_path": os.path.join(self.data_prefix['img_path'], raw_img_info['file_name']),
"gt_label": cate_one_hot,
}

def load_data_list(self) -> list[dict]:

self.coco = COCO(self.ann_file)
self.cat_ids = self.coco.getCatIds(catNms=cur_cates)
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
img_ids = self.coco.getImgIds()
data_list = []

for img_id in img_ids:
raw_img_info = self.coco.loadImgs([img_id])[0]
raw_img_info['img_id'] = img_id

ann_ids = self.coco.getAnnIds(imgIds=[img_id])
raw_ann_info = self.coco.loadAnns(ann_ids)

if len(raw_ann_info) == 0:
# print(f"Image {img_id} has no annotations, skipped.")
continue

parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)

del self.coco

return data_list

@TRANSFORMS.register_module()
class LoadImage:
Expand Down
22 changes: 20 additions & 2 deletions exps/muti_label/globals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cur_cates = [
lvis_cates = [
'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
Expand Down Expand Up @@ -241,4 +241,22 @@
'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
'yoke_(animal_equipment)', 'zebra', 'zucchini'
]
]

coco_cates = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

cur_cates = coco_cates
4 changes: 1 addition & 3 deletions exps/muti_label/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import todd

from mmengine.runner import Runner
from mmengine import Config
from mmengine.registry import RUNNERS
Expand All @@ -8,6 +6,6 @@
import exps.muti_label

if __name__ == "__main__":
config = Config.fromfile("/root/workspace/OADP/exps/muti_label/configs/clip_config.py")
config = Config.fromfile("/root/workspace/OADP/exps/muti_label/configs/clip_config_coco.py")
runner: Runner = RUNNERS.build(config)
runner.val()

0 comments on commit 35af0b7

Please sign in to comment.