diff --git a/exps/muti_label/__init__.py b/exps/muti_label/__init__.py
index d9aaff7..c2dfc53 100644
--- a/exps/muti_label/__init__.py
+++ b/exps/muti_label/__init__.py
@@ -1,3 +1,4 @@
-from .dataset import LVISDataset
+from .dataset import LVISDataset, COCODatasets
+
 from .model import CLIPModel
 from .metrics import MutiLabelMetric
\ No newline at end of file
diff --git a/exps/muti_label/configs/clip_config_coco.py b/exps/muti_label/configs/clip_config_coco.py
new file mode 100644
index 0000000..9acc5fe
--- /dev/null
+++ b/exps/muti_label/configs/clip_config_coco.py
@@ -0,0 +1,26 @@
+_base_ = ['base.py']
+
+val_dataset = dict(
+    type='COCODatasets',
+    data_root='data/coco2017',
+    data_prefix=dict(
+        img_path='val2017',
+    ),
+    ann_file='annotations/instances_val2017.json',
+    pipeline=[
+        dict(type='LoadImage'),
+        dict(type='CLIPTransforms'),
+        dict(type='PackData'),
+    ],
+)
+
+val_dataloader = dict(
+    batch_size=32,
+    dataset=val_dataset,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    collate_fn=dict(type='default_collate')
+)
+
+model = dict(
+    type='CLIPModel',
+)
\ No newline at end of file
diff --git a/exps/muti_label/dataset.py b/exps/muti_label/dataset.py
index d61e44e..e420038 100644
--- a/exps/muti_label/dataset.py
+++ b/exps/muti_label/dataset.py
@@ -4,6 +4,7 @@ import numpy as np
 from PIL import Image
 
 from lvis.lvis import LVIS
+from pycocotools.coco import COCO
 from mmengine.dataset import BaseDataset
 from mmengine.registry import DATASETS, TRANSFORMS
 
@@ -58,6 +59,79 @@ def load_data_list(self) -> list[dict]:
         del self.lvis
         return data_list
 
+
+@DATASETS.register_module()
+class COCODatasets(BaseDataset):
+    """COCO image-level multi-label dataset read with pycocotools.
+
+    Every sample carries an image path and a multi-hot vector over
+    ``cur_cates`` marking the categories annotated in that image.
+    """
+
+    def parse_data_info(self, raw_data_info):
+        """Convert one image's raw COCO records into a sample dict.
+
+        Args:
+            raw_data_info: dict holding ``raw_img_info`` (COCO image
+                record) and ``raw_ann_info`` (list of its annotations).
+
+        Returns:
+            dict with ``img_path`` and the multi-hot ``gt_label``.
+        """
+        raw_ann_info = raw_data_info['raw_ann_info']
+        raw_img_info = raw_data_info['raw_img_info']
+
+        # Several instances may share a category; keep each label once.
+        category_ids = torch.unique(torch.tensor(
+            [self.cat2label[ann['category_id']] for ann in raw_ann_info],
+            dtype=torch.long))
+        cate_one_hot = torch.zeros(len(cur_cates))
+        cate_one_hot[category_ids] = 1
+
+        return {
+            "img_path": os.path.join(self.data_prefix['img_path'], raw_img_info['file_name']),
+            "gt_label": cate_one_hot,
+        }
+
+    def load_data_list(self) -> list[dict]:
+        """Read the COCO annotation file and build the sample list.
+
+        Images without any annotation are skipped.
+
+        Returns:
+            One parsed sample dict per annotated image.
+        """
+        self.coco = COCO(self.ann_file)
+        self.cat_ids = self.coco.getCatIds(catNms=cur_cates)
+        # pycocotools returns ids in the annotation file's category order,
+        # not in catNms order; map by name so label i means cur_cates[i].
+        name2label = {name: i for i, name in enumerate(cur_cates)}
+        self.cat2label = {
+            cat['id']: name2label[cat['name']]
+            for cat in self.coco.loadCats(self.cat_ids)
+        }
+        data_list = []
+
+        for img_id in self.coco.getImgIds():
+            raw_img_info = self.coco.loadImgs([img_id])[0]
+            raw_img_info['img_id'] = img_id
+
+            ann_ids = self.coco.getAnnIds(imgIds=[img_id])
+            raw_ann_info = self.coco.loadAnns(ann_ids)
+
+            if not raw_ann_info:
+                continue
+
+            data_list.append(
+                self.parse_data_info({
+                    'raw_ann_info': raw_ann_info,
+                    'raw_img_info': raw_img_info,
+                }))
+
+        # Release the COCO index once the samples are built.
+        del self.coco
+
+        return data_list
 
 @TRANSFORMS.register_module()
 class LoadImage:
diff --git a/exps/muti_label/globals.py b/exps/muti_label/globals.py
index 63a2284..cdf5d80 100644
--- a/exps/muti_label/globals.py
+++ b/exps/muti_label/globals.py
@@ -1,4 +1,4 @@
-cur_cates = [
+lvis_cates = [
     'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol',
     'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna',
     'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
@@ -241,4 +241,22 @@
     'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
     'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
     'yoke_(animal_equipment)', 'zebra', 'zucchini'
-]
\ No newline at end of file
+]
+
+coco_cates = [
+    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
+    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+    'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
+    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
+    'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
+    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
+    'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
+    'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+    'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+    'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+cur_cates = coco_cates
\ No newline at end of file
diff --git a/exps/muti_label/test.py b/exps/muti_label/test.py
index 344ce6b..77af24c 100644
--- a/exps/muti_label/test.py
+++ b/exps/muti_label/test.py
@@ -1,5 +1,3 @@
-import todd
-
 from mmengine.runner import Runner
 from mmengine import Config
 from mmengine.registry import RUNNERS
@@ -8,6 +6,6 @@ import exps.muti_label
 
 
 if __name__ == "__main__":
-    config = Config.fromfile("/root/workspace/OADP/exps/muti_label/configs/clip_config.py")
+    config = Config.fromfile("/root/workspace/OADP/exps/muti_label/configs/clip_config_coco.py")
     runner: Runner = RUNNERS.build(config)
     runner.val()
\ No newline at end of file