-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
535 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,66 @@ | ||
{ | ||
|
||
"configurations": [ | ||
{ | ||
"name": "dp objects365", | ||
"type": "debugpy", | ||
"request": "launch", | ||
// 设置 torchrun 命令的参数 | ||
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": false, | ||
"args": [ | ||
"--nproc_per_node=1", | ||
"--nnodes=1", | ||
"--master-port=5000", | ||
"-m", | ||
"oadp.dp.train", | ||
"objects365", | ||
"configs/dp/ov_objects365.py" | ||
], | ||
}, | ||
{ | ||
"name": "dp v3det", | ||
"type": "debugpy", | ||
"request": "launch", | ||
// 设置 torchrun 命令的参数 | ||
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": false, | ||
"args": [ | ||
"--nproc_per_node=1", | ||
"--nnodes=1", | ||
"--master-port=5000", | ||
"-m", | ||
"oadp.dp.train", | ||
"v3det", | ||
"configs/dp/ov_v3det.py" | ||
], | ||
}, | ||
{ | ||
"name": "dp mixed", | ||
"type": "debugpy", | ||
"request": "launch", | ||
// 设置 torchrun 命令的参数 | ||
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": false, | ||
"args": [ | ||
"--nproc_per_node=1", | ||
"--nnodes=1", | ||
"--master-port=5000", | ||
"-m", | ||
"oadp.dp.train", | ||
"mixed", | ||
"configs/dp/ov_mixed.py" | ||
], | ||
} | ||
] | ||
|
||
} | ||
|
||
{ | ||
"name": "Python current File", | ||
"type": "debugpy", | ||
"request": "launch", | ||
"program": "exps/muti_label/test.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": false | ||
}, | ||
{ | ||
"name": "dp objects365", | ||
"type": "debugpy", | ||
"request": "launch", | ||
// 设置 torchrun 命令的参数 | ||
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": false, | ||
"args": [ | ||
"--nproc_per_node=1", | ||
"--nnodes=1", | ||
"--master-port=5000", | ||
"-m", | ||
"oadp.dp.train", | ||
"objects365", | ||
"configs/dp/ov_objects365.py" | ||
], | ||
}, | ||
{ | ||
"name": "dp v3det", | ||
"type": "debugpy", | ||
"request": "launch", | ||
// 设置 torchrun 命令的参数 | ||
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": false, | ||
"args": [ | ||
"--nproc_per_node=1", | ||
"--nnodes=1", | ||
"--master-port=5000", | ||
"-m", | ||
"oadp.dp.train", | ||
"v3det", | ||
"configs/dp/ov_v3det.py" | ||
], | ||
}, | ||
{ | ||
"name": "dp mixed", | ||
"type": "debugpy", | ||
"request": "launch", | ||
// 设置 torchrun 命令的参数 | ||
"program": "/home/jiao/anaconda3/envs/oadp/lib/python3.11/site-packages/torch/distributed/run.py", | ||
"console": "integratedTerminal", | ||
"justMyCode": false, | ||
"args": [ | ||
"--nproc_per_node=1", | ||
"--nnodes=1", | ||
"--master-port=5000", | ||
"-m", | ||
"oadp.dp.train", | ||
"mixed", | ||
"configs/dp/ov_mixed.py" | ||
], | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .dataset import LVISDataset | ||
from .model import CLIPModel | ||
from .metrics import MutiLabelMetric |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Settings shared with the sibling base.py config are pulled in through
# mmengine's `_base_` inheritance mechanism.
_base_ = ['base.py']

# Only the model type is chosen here; any further model arguments must come
# from the inherited config.
model = dict(type='CLIPModel')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import clip | ||
import torch | ||
import os | ||
import numpy as np | ||
from PIL import Image | ||
from lvis.lvis import LVIS | ||
|
||
from mmengine.dataset import BaseDataset | ||
from mmengine.registry import DATASETS, TRANSFORMS | ||
from mmengine.structures import BaseDataElement | ||
|
||
@DATASETS.register_module()
class LVISDataset(BaseDataset):
    """LVIS images paired with multi-hot, image-level category labels.

    Each sample is a dict with ``img_path`` (path of the image under
    ``data_root``) and ``gt_label`` (a float vector with one entry per
    category, set to 1 for every category annotated in the image).

    Args:
        *args: Forwarded unchanged to ``BaseDataset``.
        categories: Category collection; only its length is used in this
            class, to size the label vector.
        **kwargs: Forwarded unchanged to ``BaseDataset``.
    """

    def __init__(self, *args, categories, **kwargs):
        # NOTE(review): assigned before the base constructor runs, presumably
        # because BaseDataset may already load the data list inside
        # __init__ -- confirm against mmengine.
        self.categories = categories
        super().__init__(*args, **kwargs)

    def parse_data_info(self, raw_data_info):
        """Convert one image's raw LVIS records into a training sample.

        Args:
            raw_data_info: Dict with ``raw_ann_info`` (list of annotation
                dicts, each holding a ``category_id``) and ``raw_img_info``
                (image dict holding a ``file_name``).

        Returns:
            dict: ``img_path`` and the multi-hot ``gt_label`` tensor.
        """
        raw_ann_info = raw_data_info['raw_ann_info']
        raw_img_info = raw_data_info['raw_img_info']

        # Ids are shifted by -1 to become vector indices.
        # NOTE(review): this assumes category ids are 1-based and contiguous
        # up to len(self.categories) -- confirm against the annotation file.
        category_ids = torch.unique(
            torch.tensor(
                [ann['category_id'] for ann in raw_ann_info],
                # Explicit dtype keeps an empty annotation list indexable.
                dtype=torch.long,
            )
        ) - 1

        # Multi-hot vector: the ids are unique, so writing 1 at each index is
        # equivalent to summing identity-matrix rows, without building an
        # N x N matrix for every image.
        gt_label = torch.zeros(len(self.categories))
        gt_label[category_ids] = 1

        return {
            "img_path": os.path.join(self.data_root, raw_img_info['file_name']),
            "gt_label": gt_label,
        }

    def load_data_list(self) -> list[dict]:
        """Read the LVIS annotation file and build the list of samples.

        Images that have no annotations are skipped and reported on stdout.

        Returns:
            list[dict]: One parsed sample per annotated image.
        """
        # Kept local so the LVIS handle never lingers on the dataset object,
        # even when loading fails part-way through.
        lvis = LVIS(self.ann_file)
        data_list = []

        for img_id in lvis.get_img_ids():
            raw_img_info = lvis.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id
            # The relative file name is the COCO URL with its host prefix
            # removed.
            raw_img_info['file_name'] = raw_img_info['coco_url'].replace(
                'http://images.cocodataset.org/', '')

            ann_ids = lvis.get_ann_ids(img_ids=[img_id])
            raw_ann_info = lvis.load_anns(ann_ids)
            if len(raw_ann_info) == 0:
                print(f"Image {img_id} has no annotations, skipped.")
                continue

            data_list.append(
                self.parse_data_info({
                    'raw_ann_info': raw_ann_info,
                    'raw_img_info': raw_img_info,
                })
            )

        return data_list
|
||
@TRANSFORMS.register_module()
class LoadImage:
    """Transform that opens the image referenced by ``data['img_path']``."""

    def __call__(self, data: dict) -> dict:
        """Open the image with PIL and store it under ``data['img']``.

        Args:
            data: Sample dict containing an ``img_path`` entry.

        Returns:
            dict: The same ``data`` dict, updated in place with ``img``.
        """
        data['img'] = Image.open(data['img_path'])
        return data
|
||
@TRANSFORMS.register_module()
class CLIPTransforms:
    """Transform that applies CLIP's own image preprocessing to ``data['img']``."""

    def __init__(self) -> None:
        # clip.load returns (model, preprocess); only the preprocessing
        # callable is kept and the model itself is discarded.
        # NOTE(review): this loads the whole ViT-B/32 model just to obtain
        # the transform -- consider a lighter way to build it.
        _, self.clip_transforms = clip.load("ViT-B/32", device="cpu")

    def __call__(self, data: dict) -> dict:
        """Replace ``data['img']`` with its CLIP-preprocessed version.

        Args:
            data: Sample dict containing an ``img`` entry.

        Returns:
            dict: The same ``data`` dict, updated in place.
        """
        data['img'] = self.clip_transforms(data['img'])
        return data
|
||
@TRANSFORMS.register_module()
class PackData:
    """Final transform that renames sample fields to the keys the model expects."""

    def __call__(self, data: dict) -> dict:
        """Return a new dict holding only the label and the image.

        Args:
            data: Sample dict containing ``gt_label`` and ``img`` entries.

        Returns:
            dict: ``data_samples`` (the label) and ``batch_inputs`` (the image).
        """
        return {
            'data_samples': data['gt_label'],
            'batch_inputs': data['img'],
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from sklearn.metrics import classification_report | ||
|
||
from mmengine.logging import MMLogger | ||
from mmengine.evaluator import BaseMetric | ||
from mmengine.registry import METRICS | ||
|
||
import numpy as np | ||
|
||
|
||
@METRICS.register_module()
class MutiLabelMetric(BaseMetric):
    """Multi-label metric that reports sklearn's ``classification_report``.

    Args:
        threshold: Logits strictly greater than this value count as a
            positive prediction.
        categories: Category names, passed to the report as ``target_names``.
        collect_device: Forwarded to ``BaseMetric``.
        prefix: Forwarded to ``BaseMetric``.
        collect_dir: Forwarded to ``BaseMetric``.
    """

    default_prefix = 'MutiLabel'

    def __init__(self, threshold, categories, collect_device='cpu',
                 prefix=None, collect_dir=None):
        super().__init__(collect_device, prefix, collect_dir)
        self.threshold = threshold
        self.categories = categories

    def process(self, data_batch: list[dict], data_samples: list[dict]):
        """Threshold the logits of one batch and store them with the labels.

        Only ``data_samples[0]`` is read; it must provide ``pred_logits``
        and ``gt_label``.
        """
        outputs = data_samples[0]
        predicted = outputs['pred_logits'] > self.threshold
        target = outputs['gt_label']

        # Both are stored as NumPy arrays so they can be concatenated later.
        self.results.append({
            'pred': predicted.cpu().numpy(),
            'gt': target.cpu().numpy(),
        })

    def compute_metrics(self, results: list[dict]) -> dict:
        """Build the classification report over all processed batches.

        Args:
            results (list[dict]): The per-batch entries stored by ``process``.

        Returns:
            dict: A single key, ``classification_report``, holding the
            report produced by sklearn.
        """
        # Stack the per-batch arrays along the first axis.
        all_preds = np.concatenate([entry['pred'] for entry in results])
        all_gts = np.concatenate([entry['gt'] for entry in results])

        report = classification_report(
            all_gts, all_preds, target_names=self.categories)

        # The report is also written to the mmengine logger.
        MMLogger.get_instance('mmengine').info(report)
        return {'classification_report': report}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import clip | ||
import torch | ||
|
||
from mmengine.model import BaseModel | ||
from mmengine.registry import MODELS | ||
|
||
|
||
@MODELS.register_module()
class CLIPModel(BaseModel):
    """Zero-shot classifier that scores images against category prompts with CLIP.

    Args:
        categories: Category names; they are tokenized once at construction
            time and reused on every forward pass.
    """

    def __init__(self, categories: list[str]) -> None:
        super().__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cpu")
        self.cate_text = clip.tokenize(categories)

    @torch.no_grad()
    def forward(self, batch_inputs, data_samples, mode='tensor', **kwargs) -> tuple[dict, None]:
        """Compute image-to-category logits for a batch.

        Args:
            batch_inputs: Preprocessed image tensor fed to ``encode_image``.
            data_samples: Ground-truth labels; returned untouched alongside
                the predictions.
            mode: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.

        Returns:
            tuple[dict, None]: A dict with ``pred_logits`` and ``gt_label``,
            followed by ``None``.
        """
        # Follow the device of the inputs instead of forcing CUDA, so the
        # model also runs on CPU-only machines and non-default GPUs.
        text_tokens = self.cate_text.to(batch_inputs.device)
        text_features = self.model.encode_text(text_tokens)
        image_features = self.model.encode_image(batch_inputs)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()

        # One row per input image, one column per category prompt.
        pred = {
            'pred_logits': logits_per_image,
            'gt_label': data_samples
        }
        return pred, None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import todd | ||
|
||
from mmengine.runner import Runner | ||
from mmengine import Config | ||
from mmengine.registry import RUNNERS | ||
from mmengine.runner import Runner | ||
|
||
import exps.muti_label | ||
|
||
if __name__ == "__main__":
    # Build the runner straight from the CLIP config file and run the
    # validation loop only.
    runner: Runner = RUNNERS.build(
        Config.fromfile("/root/workspace/OADP/exps/muti_label/configs/clip_config.py")
    )
    runner.val()