Skip to content

Commit

Permalink
add multi-label
Browse files Browse the repository at this point in the history
  • Loading branch information
Noietch committed Nov 10, 2024
1 parent 94cd475 commit bb789a6
Show file tree
Hide file tree
Showing 8 changed files with 535 additions and 59 deletions.
123 changes: 64 additions & 59 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -1,61 +1,66 @@
{

"configurations": [
{
"name": "dp objects365",
"type": "debugpy",
"request": "launch",
// Set the arguments for the torchrun command
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--nproc_per_node=1",
"--nnodes=1",
"--master-port=5000",
"-m",
"oadp.dp.train",
"objects365",
"configs/dp/ov_objects365.py"
],
},
{
"name": "dp v3det",
"type": "debugpy",
"request": "launch",
// Set the arguments for the torchrun command
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--nproc_per_node=1",
"--nnodes=1",
"--master-port=5000",
"-m",
"oadp.dp.train",
"v3det",
"configs/dp/ov_v3det.py"
],
},
{
"name": "dp mixed",
"type": "debugpy",
"request": "launch",
// Set the arguments for the torchrun command
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--nproc_per_node=1",
"--nnodes=1",
"--master-port=5000",
"-m",
"oadp.dp.train",
"mixed",
"configs/dp/ov_mixed.py"
],
}
]

}

{
"name": "Python current File",
"type": "debugpy",
"request": "launch",
"program": "exps/muti_label/test.py",
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "dp objects365",
"type": "debugpy",
"request": "launch",
// Set the arguments for the torchrun command
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--nproc_per_node=1",
"--nnodes=1",
"--master-port=5000",
"-m",
"oadp.dp.train",
"objects365",
"configs/dp/ov_objects365.py"
],
},
{
"name": "dp v3det",
"type": "debugpy",
"request": "launch",
// Set the arguments for the torchrun command
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--nproc_per_node=1",
"--nnodes=1",
"--master-port=5000",
"-m",
"oadp.dp.train",
"v3det",
"configs/dp/ov_v3det.py"
],
},
{
"name": "dp mixed",
"type": "debugpy",
"request": "launch",
// Set the arguments for the torchrun command
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--nproc_per_node=1",
"--nnodes=1",
"--master-port=5000",
"-m",
"oadp.dp.train",
"mixed",
"configs/dp/ov_mixed.py"
],
}
]
}
3 changes: 3 additions & 0 deletions exps/muti_label/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .dataset import LVISDataset
from .model import CLIPModel
from .metrics import MutiLabelMetric
281 changes: 281 additions & 0 deletions exps/muti_label/configs/base.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions exps/muti_label/configs/clip_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = ['base.py']

model = dict(
type='CLIPModel',
)
82 changes: 82 additions & 0 deletions exps/muti_label/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import clip
import torch
import os
import numpy as np
from PIL import Image
from lvis.lvis import LVIS

from mmengine.dataset import BaseDataset
from mmengine.registry import DATASETS, TRANSFORMS
from mmengine.structures import BaseDataElement

@DATASETS.register_module()
class LVISDataset(BaseDataset):
    """Multi-label classification view of an LVIS annotation file.

    Every sample pairs an image path with a multi-hot vector that marks the
    categories having at least one annotation in that image.

    Args:
        *args: Positional arguments forwarded to ``BaseDataset``.
        categories: Category collection. Only its length is used in this
            class, to size the multi-hot label vector.
        **kwargs: Keyword arguments forwarded to ``BaseDataset``.
    """

    def __init__(self, *args, categories, **kwargs):
        # Set before calling the base constructor, in case that constructor
        # already triggers data loading (which needs ``self.categories``).
        self.categories = categories
        super().__init__(*args, **kwargs)

    def parse_data_info(self, raw_data_info):
        """Convert the raw image/annotation records of one image to a sample.

        Args:
            raw_data_info: Dict with ``'raw_ann_info'`` (list of annotation
                dicts carrying ``'category_id'``) and ``'raw_img_info'``
                (dict carrying ``'file_name'``).

        Returns:
            Dict with ``'img_path'`` (``data_root`` joined with the file
            name) and ``'gt_label'`` (multi-hot float tensor of length
            ``len(self.categories)``).
        """
        raw_ann_info = raw_data_info['raw_ann_info']
        raw_img_info = raw_data_info['raw_img_info']

        # Distinct category ids shifted to zero-based indices.
        # NOTE(review): assumes category ids are 1-based and line up with
        # the order of ``self.categories`` -- confirm against the ann file.
        # ``dtype=torch.long`` keeps the tensor usable as an index even
        # when the annotation list is empty.
        category_ids = torch.unique(
            torch.tensor(
                [ann['category_id'] for ann in raw_ann_info],
                dtype=torch.long,
            )
        ) - 1

        # Multi-hot encoding. Writing into a zero vector gives the same
        # result as summing rows of an identity matrix, without building a
        # (num_categories x num_categories) matrix for every image.
        cate_one_hot = torch.zeros(len(self.categories))
        cate_one_hot[category_ids] = 1.0

        return {
            "img_path": os.path.join(self.data_root, raw_img_info['file_name']),
            "gt_label": cate_one_hot,
        }

    def load_data_list(self) -> list[dict]:
        """Build the sample list from the LVIS annotation file.

        Images without any annotation are skipped (a message is printed for
        each of them).

        Returns:
            List of dicts produced by :meth:`parse_data_info`.
        """
        self.lvis = LVIS(self.ann_file)

        data_list = []
        for img_id in self.lvis.get_img_ids():
            raw_img_info = self.lvis.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id
            # Derive a path relative to the data root by stripping the host
            # prefix from the COCO URL.
            raw_img_info['file_name'] = raw_img_info['coco_url'].replace(
                'http://images.cocodataset.org/', '')

            ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.lvis.load_anns(ann_ids)

            if len(raw_ann_info) == 0:
                print(f"Image {img_id} has no annotations, skipped.")
                continue

            data_list.append(
                self.parse_data_info({
                    'raw_ann_info': raw_ann_info,
                    'raw_img_info': raw_img_info,
                })
            )

        # The LVIS handle is only needed while building the list.
        del self.lvis

        return data_list

@TRANSFORMS.register_module()
class LoadImage:
    """Transform that opens the image referenced by ``data['img_path']``."""

    def __call__(self, data: dict) -> dict:
        """Open the image with PIL and store it under ``data['img']``.

        Args:
            data: Sample dict containing ``'img_path'``.

        Returns:
            The same dict, mutated in place, with the opened
            ``PIL.Image.Image`` added under ``'img'``.
        """
        data['img'] = Image.open(data['img_path'])
        return data

@TRANSFORMS.register_module()
class CLIPTransforms:
    """Transform that applies CLIP's own image preprocessing pipeline."""

    def __init__(self) -> None:
        # ``clip.load`` returns a (model, preprocess) pair; only the
        # preprocessing callable is kept, the model is discarded.
        _, self.clip_transforms = clip.load("ViT-B/32", device="cpu")

    def __call__(self, data: dict) -> dict:
        """Replace ``data['img']`` with its CLIP-preprocessed version.

        Args:
            data: Sample dict containing ``'img'``.

        Returns:
            The same dict, mutated in place, with ``'img'`` replaced by the
            output of the CLIP preprocessing callable.
        """
        data['img'] = self.clip_transforms(data['img'])
        return data

@TRANSFORMS.register_module()
class PackData:
    """Transform that repacks a sample into the keys the model consumes."""

    def __call__(self, data: dict) -> dict:
        """Return a new dict holding only the label and the image.

        Args:
            data: Sample dict containing ``'gt_label'`` and ``'img'``.

        Returns:
            Dict with ``'data_samples'`` (the label) and ``'batch_inputs'``
            (the image).
        """
        return {
            'data_samples': data['gt_label'],
            'batch_inputs': data['img'],
        }
54 changes: 54 additions & 0 deletions exps/muti_label/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from sklearn.metrics import classification_report

from mmengine.logging import MMLogger
from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS

import numpy as np


@METRICS.register_module()
class MutiLabelMetric(BaseMetric):
    """Multi-label metric that logs a scikit-learn classification report.

    Args:
        threshold: Logits strictly greater than this value count as a
            positive prediction.
        categories: Category names, passed to the report as
            ``target_names``.
        collect_device: Forwarded to ``BaseMetric``.
        prefix: Forwarded to ``BaseMetric``.
        collect_dir: Forwarded to ``BaseMetric``.
    """

    default_prefix = 'MutiLabel'

    def __init__(
        self,
        threshold,
        categories,
        collect_device='cpu',
        prefix=None,
        collect_dir=None,
    ):
        super().__init__(collect_device, prefix, collect_dir)
        self.threshold = threshold
        self.categories = categories

    def process(self, data_batch: list[dict], data_samples: list[dict]):
        """Binarise one batch of predictions and queue it for evaluation.

        Only ``data_samples[0]`` is read; it must provide ``'pred_logits'``
        and ``'gt_label'``.
        NOTE(review): presumably that first element carries the whole batch
        -- confirm against the model output.
        """
        outputs = data_samples[0]
        predicted = outputs['pred_logits'] > self.threshold
        target = outputs['gt_label']

        self.results.append({
            'pred': predicted.cpu().numpy(),
            'gt': target.cpu().numpy(),
        })

    def compute_metrics(self, results: list[dict]) -> dict:
        """Build the classification report over all processed batches.

        Args:
            results: Entries stored by :meth:`process`, one per call.

        Returns:
            Dict with the single key ``'classification_report'`` holding
            the value returned by ``classification_report``.
        """

        def stack(key):
            # Join the per-batch arrays along the first axis.
            return np.concatenate([batch[key] for batch in results])

        preds = stack('pred')
        gts = stack('gt')
        report = classification_report(
            gts, preds, target_names=self.categories)

        MMLogger.get_instance('mmengine').info(report)
        return {'classification_report': report}
33 changes: 33 additions & 0 deletions exps/muti_label/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import clip
import torch

from mmengine.model import BaseModel
from mmengine.registry import MODELS


@MODELS.register_module()
class CLIPModel(BaseModel):
    """Zero-shot multi-label scorer built on CLIP ViT-B/32.

    Args:
        categories: Category names; they are tokenised once at construction
            and used as the text prompts.
    """

    def __init__(self, categories: list[str]) -> None:
        super().__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cpu")
        # Plain tensor attribute (not a registered buffer), so it does not
        # follow the module across devices; ``forward`` moves it explicitly.
        self.cate_text = clip.tokenize(categories)

    @torch.no_grad()
    def forward(
        self, batch_inputs, data_samples, mode='tensor', **kwargs
    ) -> tuple[dict, None]:
        """Score every image against every category prompt.

        Args:
            batch_inputs: Image tensor passed to ``encode_image``.
            data_samples: Ground-truth labels; returned unchanged.
            mode: Accepted for interface compatibility; not used.
            **kwargs: Ignored.

        Returns:
            ``(pred, None)`` where ``pred`` is a dict with ``'pred_logits'``
            (scaled cosine similarities, one row per image and one column
            per category) and ``'gt_label'`` (``data_samples``).
        """
        # Follow the input's device instead of hard-coding ``.cuda()`` so
        # the model also runs on CPU or on a non-default GPU.
        cate_text = self.cate_text.to(batch_inputs.device)

        # NOTE(review): the text features do not depend on the batch and
        # could be cached if the text encoder is frozen -- confirm.
        text_features = self.model.encode_text(cate_text)
        image_features = self.model.encode_image(batch_inputs)

        # L2-normalise so the dot product below is a cosine similarity.
        image_features = image_features / image_features.norm(
            dim=1, keepdim=True)
        text_features = text_features / text_features.norm(
            dim=1, keepdim=True)

        # Cosine similarity scaled by CLIP's logit scale.
        logit_scale = self.model.logit_scale.exp()
        # shape = [batch_size, num_categories]
        logits_per_image = logit_scale * image_features @ text_features.t()

        pred = {
            'pred_logits': logits_per_image,
            'gt_label': data_samples,
        }
        return pred, None
13 changes: 13 additions & 0 deletions exps/muti_label/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import todd

from mmengine.runner import Runner
from mmengine import Config
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

import exps.muti_label

if __name__ == "__main__":
    from pathlib import Path

    # Locate the config next to this script instead of using a
    # machine-specific absolute path, so the script runs from any checkout.
    config_path = Path(__file__).resolve().parent / "configs" / "clip_config.py"

    config = Config.fromfile(str(config_path))
    runner: Runner = RUNNERS.build(config)
    # Run validation only.
    runner.val()

0 comments on commit bb789a6

Please sign in to comment.