Skip to content

Commit

Permalink
update mutilabel classification
Browse files Browse the repository at this point in the history
  • Loading branch information
Noietch committed Nov 15, 2024
1 parent 35af0b7 commit 2f22be9
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 45 deletions.
7 changes: 3 additions & 4 deletions exps/muti_label/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .dataset import LVISDataset, COCODatasets

from .model import CLIPModel
from .metrics import MutiLabelMetric
from .dataset import *
from .model import *
from .metrics import *
10 changes: 7 additions & 3 deletions exps/muti_label/configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

val_dataset = dict(
type='LVISDataset',
data_root='data/lvis',
data_root='data/lvis_v1',
ann_file='annotations/lvis_v1_val.json',
pipeline=[
dict(type='LoadImage'),
Expand All @@ -30,5 +30,9 @@
)

work_dir = 'work_dirs/'
launcher = 'pytorch'
log_processor = dict(window_size=1)
# launcher = 'pytorch'
log_processor = dict(window_size=1)
visualizer_cfg = dict(type='Visualizer',
# name='vis',
# save_dir='temp_dir',
vis_backends=[dict(type='WandbVisBackend')])
9 changes: 8 additions & 1 deletion exps/muti_label/configs/clip_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@

model = dict(
type='CLIPModel',
)
)

val_evaluator = [
dict(
type='MutiLabelMetric',
threshold=0.1,
),
]
4 changes: 2 additions & 2 deletions exps/muti_label/configs/clip_config_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

val_dataset = dict(
type='COCODatasets',
data_root='data/coco2017',
data_root='data/coco',
data_prefix=dict(
img_path='val2017',
),
Expand All @@ -15,7 +15,7 @@
)

val_dataloader = dict(
batch_size=32,
batch_size=64,
dataset=val_dataset,
sampler=dict(type='DefaultSampler', shuffle=False),
collate_fn=dict(type='default_collate')
Expand Down
32 changes: 32 additions & 0 deletions exps/muti_label/configs/ram_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
_base_ = ['base.py']

val_dataset = dict(
type='LVISDataset',
data_root='data/lvis_v1',
ann_file='annotations/lvis_v1_val.json',
pipeline=[
dict(type='LoadImage'),
dict(type='RAMTransforms'),
dict(type='PackData'),
],
)

val_dataloader = dict(
batch_size=32,
dataset=val_dataset,
sampler=dict(type='DefaultSampler', shuffle=False),
collate_fn=dict(type='default_collate')
)

model = dict(
type='RAMModel',
model_path='pretrained/ram/ram_plus_swin_large_14m.pth',
llm_tag_des='data/lvis_v1/annotations/openset_label_embedding.pth'
)

val_evaluator = [
dict(
type='MutiLabelMetric',
threshold=0.5,
),
]
23 changes: 19 additions & 4 deletions exps/muti_label/dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import clip
import torch
import os
import numpy as np
from PIL import Image
from lvis.lvis import LVIS
from pycocotools.coco import COCO
from ram import get_transform

from mmengine.structures.base_data_element import BaseDataElement
from mmengine.dataset import BaseDataset
from mmengine.registry import DATASETS, TRANSFORMS

Expand Down Expand Up @@ -72,7 +73,8 @@ def parse_data_info(self, raw_data_info):
# to one-hot
category_ids = torch.unique(torch.tensor([self.cat2label[ann['category_id']] for ann in raw_ann_info]))
cate_one_hot = torch.eye(len(cur_cates))[category_ids].sum(dim=0)

if (cate_one_hot == 0).all():
print(f"Image {raw_img_info['img_id']} has no annotations, skipped.")
return {
"img_path": os.path.join(self.data_prefix['img_path'], raw_img_info['file_name']),
"gt_label": cate_one_hot,
Expand Down Expand Up @@ -123,11 +125,24 @@ def __init__(self) -> None:
def __call__(self, data: dict) -> tuple[torch.Tensor, torch.Tensor]:
data['img'] = self.clip_transforms(data['img'])
return data


@TRANSFORMS.register_module()
class RAMTransforms:
def __init__(self) -> None:
self.transform = get_transform(image_size=384)

def __call__(self, data: dict) -> dict:
data['img'] = self.transform(data['img'])
return data


@TRANSFORMS.register_module()
class PackData:
def __call__(self, data: dict) -> dict:
packed_results = {}
packed_results['data_samples'] = data['gt_label']
packed_results['data_samples'] = BaseDataElement(
metainfo=dict(img_path=data['img_path']),
gt_label=data['gt_label'],
)
packed_results['batch_inputs'] = data['img']
return packed_results
2 changes: 1 addition & 1 deletion exps/muti_label/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,4 +259,4 @@
'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

cur_cates = coco_cates
cur_cates = lvis_cates
83 changes: 56 additions & 27 deletions exps/muti_label/metrics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import cv2
import os
import torch
import numpy as np
from sklearn.metrics import accuracy_score, recall_score

from mmengine.logging import MMLogger
import mmcv
from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS
from mmengine.visualization import Visualizer

from exps.muti_label.globals import cur_cates


@METRICS.register_module()
class MutiLabelMetric(BaseMetric):

Expand All @@ -15,43 +18,69 @@ class MutiLabelMetric(BaseMetric):
def __init__(self, threshold, collect_device = 'cpu', prefix = None, collect_dir = None):
super().__init__(collect_device, prefix, collect_dir)
self.threshold = threshold # threshold for classification prediction
self.is_visualize = True

def process(self, data_batch: list[dict], data_samples: list[dict]):
"""Process the data batch and store the classification prediction results"""
pred_label = (data_samples[0]['pred_logits'] > self.threshold)
gt_label = data_samples[0]['gt_label']
def _average_precision(sekf, output: np.ndarray, target: np.ndarray) -> float:
epsilon = 1e-8

# sort examples
indices = output.argsort()[::-1]
# Computes prec@i
total_count_ = np.cumsum(np.ones((len(output), 1)))

target_ = target[indices]
ind = target_ == 1
pos_count_ = np.cumsum(ind)
total = pos_count_[-1]
pos_count_[np.logical_not(ind)] = 0
pp = pos_count_ / total_count_
precision_at_i_ = np.sum(pp)
precision_at_i = precision_at_i_ / (total + epsilon)

return precision_at_i

def get_mAP(self, gts, preds):
APs = []
_, num_classes = gts.shape
APs = np.zeros(num_classes)
for k in range(num_classes): # AP for each class
APs[k] = self._average_precision(preds[:, k], gts[:, k])
return APs.mean()

# d_label = np.argwhere(pred_label[0].cpu().numpy() == 1)[0]
# print([cur_cates[i] for i in d_label])
def select_cates(self, preds: torch.Tensor):
cates = [cur_cates[i] for i in range(len(cur_cates)) if preds[i]]
return ', '.join(cates)

# fetch classification prediction results and category labels
def visualize(self, data_samples: list[dict], pred_labels: torch.Tensor):
self.visualizer: Visualizer = Visualizer.get_current_instance()
for sample, pred in zip(data_samples[0]['data_samples'], pred_labels):
img = mmcv.imread(sample.img_path, channel_order='rgb')
img_name = os.path.basename(sample.img_path)
self.visualizer.set_image(img)
text = self.select_cates(pred.cpu().numpy())
self.visualizer.draw_texts(text, torch.tensor([10, 20]))
self.visualizer.add_image(img_name, self.visualizer.get_image())

def process(self, data_batch: list[dict], data_samples: list[dict]):
pred_label = data_samples[0]['pred_logits']
gt_label = torch.cat([sample.gt_label.unsqueeze(0) for sample in data_samples[0]['data_samples']], dim=0)

result = {
'pred': pred_label.cpu().numpy(),
'gt': gt_label.cpu().numpy()
}

# store the results of the current batch into self.results
self.results.append(result)

def compute_metrics(self, results: list[dict]) -> dict:
"""Compute the metrics from processed results.
if self.is_visualize:
self.visualize(data_samples, pred_label > self.threshold)
self.is_visualize = False

Args:
results (dict): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""

# aggregate the classification prediction results and category labels for all samples
def compute_metrics(self, results: list[dict]) -> dict:
preds = np.concatenate([res['pred'] for res in results])
gts = np.concatenate([res['gt'] for res in results])
accuracy = accuracy_score(gts, preds)
recall = recall_score(gts, preds, average='macro')
mAP = self.get_mAP(gts, preds)
# # log the classification report
results = {
'accuracy': accuracy,
'recall': recall
'mAP': mAP,
}
return results
74 changes: 72 additions & 2 deletions exps/muti_label/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import clip
import torch
import json
import numpy as np
import torch.nn as nn
from torch.nn import functional as F

from ram.models import ram_plus

from mmengine.model import BaseModel
from mmengine.registry import MODELS
Expand All @@ -17,9 +23,73 @@ def __init__(self) -> None:
def forward(self, batch_inputs, data_samples, mode='tensor', **kwargs) -> torch.Tensor:
logits_per_image, logits_per_text = self.model(batch_inputs, self.cate_text.cuda())
probs = logits_per_image.softmax(dim=-1)
# shape = [global_batch_size, global_batch_size]
pred = {
'pred_logits': probs,
'gt_label': data_samples
'data_samples': data_samples
}
return pred, None

@MODELS.register_module()
class RAMModel(BaseModel):
def __init__(self, model_path: str, llm_tag_des:str, image_size=384) -> None:
super().__init__()
self.model = ram_plus(pretrained=model_path, image_size=image_size, vit='swin_l')

label_info = torch.load(llm_tag_des)
openset_label_embedding = label_info['openset_label_embedding']
openset_categories = label_info['openset_categories']

self.model.tag_list = np.array(openset_categories)
self.model.label_embed = nn.Parameter(openset_label_embedding.float())
self.model.num_class = len(openset_categories)
self.model.class_threshold = torch.ones(self.model.num_class) * 0.5

def model_forward(self, image):

image_embeds = self.model.image_proj(self.model.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1],
dtype=torch.long).to(image.device)

image_cls_embeds = image_embeds[:, 0, :]
image_spatial_embeds = image_embeds[:, 1:, :]

bs = image_spatial_embeds.shape[0]

des_per_class = int(self.model.label_embed.shape[0] / self.model.num_class)

image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
reweight_scale = self.model.reweight_scale.exp()
logits_per_image = (reweight_scale * image_cls_embeds @ self.model.label_embed.t())
logits_per_image = logits_per_image.view(bs, -1,des_per_class)

weight_normalized = F.softmax(logits_per_image, dim=2)
label_embed_reweight = torch.empty(bs, self.model.num_class, 512).to(image.device).to(image.dtype)

for i in range(bs):
# 这里对 value_ori 进行 reshape,然后使用 broadcasting
reshaped_value = self.model.label_embed.view(-1, des_per_class, 512)
product = weight_normalized[i].unsqueeze(-1) * reshaped_value
label_embed_reweight[i] = product.sum(dim=1)

label_embed = torch.nn.functional.relu(self.model.wordvec_proj(label_embed_reweight))

# recognized image tags using alignment decoder
tagging_embed = self.model.tagging_head(
encoder_embeds=label_embed,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=False,
mode='tagging',
)

logits = self.model.fc(tagging_embed[0]).squeeze(-1)
return logits

@torch.no_grad()
def forward(self, batch_inputs, data_samples, mode='tensor', **kwargs) -> torch.Tensor:
probs = self.model_forward(batch_inputs)
pred = {
'pred_logits': probs,
'data_samples': data_samples
}
return pred, None
2 changes: 1 addition & 1 deletion exps/muti_label/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
import exps.muti_label

if __name__ == "__main__":
config = Config.fromfile("/root/workspace/OADP/exps/muti_label/configs/clip_config_coco.py")
config = Config.fromfile("exps/muti_label/configs/ram_config.py")
runner: Runner = RUNNERS.build(config)
runner.val()

0 comments on commit 2f22be9

Please sign in to comment.