Commit

update training
LutingWang committed Nov 20, 2024
1 parent ff3a543 commit 4a5cd45
Showing 7 changed files with 151 additions and 14 deletions.
24 changes: 13 additions & 11 deletions configs/datasets/pretrain.py
@@ -1,4 +1,5 @@
set_mode = '_mini'
# split = '_mini'
split = ''
train_batch_size_per_gpu = 2
test_batch_size_per_gpu = 1

@@ -93,36 +94,37 @@
        type='mmyolo.LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
        pad_val=dict(img=114),
    ),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
    dict(type='LoadText'),
    dict(type='mmdet.PackDetInputs',
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                    'scale_factor', 'pad_param', 'texts'))
]

obj365v1_train = dict(
    type='MultiModalDataset',
    dataset=dict(
        type='YOLOv5Objects365V2Dataset',
        data_root='data/objects365v2/',
        ann_file=f'annotations/zhiyuan_objv2_train{set_mode}.json',
        ann_file=f'annotations/zhiyuan_objv2_train{split}.json',
        data_prefix=dict(img='train/'),
        filter_cfg=dict(filter_empty_gt=False, min_size=32)),
    class_text_path='data/texts/obj365v2_class_texts.json',
    pipeline=train_pipeline_stage1)

mixgrounding_train = dict(type='YOLOv5MixedGroundingDataset',
                          data_root='data/mixed_grounding/',
                          ann_file=f'annotations/final_mixed_train_no_coco{set_mode}.json',
                          ann_file=f'annotations/final_mixed_train_no_coco{split}.json',
                          data_prefix=dict(img='images/'),
                          filter_cfg=dict(filter_empty_gt=False, min_size=32),
                          pipeline=train_pipeline_stage1)

flickr_train = dict(
    type='YOLOv5MixedGroundingDataset',
    data_root='data/flickr/',
    ann_file=f'annotations/final_flickr_separateGT_train{set_mode}.json',
    ann_file=f'annotations/final_flickr_separateGT_train{split}.json',
    data_prefix=dict(img='images/'),
    filter_cfg=dict(filter_empty_gt=True, min_size=32),
    pipeline=train_pipeline_stage1)
@@ -144,7 +146,7 @@
        type='YOLOv5LVISV1Dataset',
        data_root='data/lvis/',
        test_mode=True,
        ann_file='annotations/lvis_v1_minival_inserted_image_name_mini.json',
        ann_file=f'annotations/lvis_v1_minival_inserted_image_name{split}.json',
        data_prefix=dict(img=''),
        batch_shapes_cfg=None
    ),
@@ -171,6 +173,6 @@

val_evaluator = dict(
    type='mmdet.LVISMetric',
    ann_file='data/lvis/annotations/lvis_v1_minival_inserted_image_name_mini.json',
    ann_file=f'data/lvis/annotations/lvis_v1_minival_inserted_image_name{split}.json',
    metric='bbox'
)
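For reference, a minimal sketch (not part of the committed config) of how the `split` toggle above composes the annotation paths used by the Objects365, mixed-grounding, Flickr and LVIS entries; restoring `split = '_mini'` switches every dataset to its debug subset.

# Sketch: how the `split` toggle expands into the annotation file names
# referenced in the config above (paths copied from the diff).
split = ''        # full annotations
# split = '_mini' # debug subset

ann_files = {
    'objects365v2': f'annotations/zhiyuan_objv2_train{split}.json',
    'mixed_grounding': f'annotations/final_mixed_train_no_coco{split}.json',
    'flickr': f'annotations/final_flickr_separateGT_train{split}.json',
    'lvis_minival': f'annotations/lvis_v1_minival_inserted_image_name{split}.json',
}
print(ann_files['objects365v2'])  # -> annotations/zhiyuan_objv2_train.json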
11 changes: 10 additions & 1 deletion configs/models/oadp.py
@@ -1,7 +1,16 @@
_base_ = ['./faster-rcnn_r50_fpn.py']

cls_predictor_cfg = dict(
    type='OADPClassifier',
    text_model=dict(
        type="HuggingCLIPLanguageBackbone",
        model_name='openai/clip-vit-base-patch32',
    )
)

model = dict(
    type='OADP',
    roi_head=dict(bbox_head=dict(num_classes=365))
    roi_head=dict(
        bbox_head=dict(cls_predictor_cfg=cls_predictor_cfg),
    ),
)
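The config above deliberately leaves `in_features` and `out_features` unset. Below is a minimal sketch of how a standard mmdet ConvFC bbox head is assumed to complete `cls_predictor_cfg` at build time; the helper name and the placeholder values are assumptions for illustration, not taken from this commit.

# Hedged sketch: how an mmdet-style bbox head typically completes
# `cls_predictor_cfg` before building it.
from mmdet.registry import MODELS

def build_cls_predictor(cls_predictor_cfg: dict, cls_last_dim: int, num_classes: int):
    cfg = dict(cls_predictor_cfg)
    # The head injects the feature width and the class count (+ background)
    # at build time, which is why OADPClassifier accepts `in_features` and
    # `out_features` even though the config above never sets them.
    cfg.update(in_features=cls_last_dim, out_features=num_classes + 1)
    return MODELS.build(cfg)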
1 change: 1 addition & 0 deletions oadp/models/__init__.py
@@ -1 +1,2 @@
from .detectors import *
from .classifier import *
98 changes: 98 additions & 0 deletions oadp/models/classifier.py
@@ -0,0 +1,98 @@
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from mmengine.model import BaseModule
from mmdet.registry import MODELS
from mmdet.utils import OptMultiConfig
from transformers import (AutoTokenizer, CLIPTextConfig)
from transformers import CLIPTextModelWithProjection as CLIPTP

from ..utils.globals import Globals

@MODELS.register_module()
class HuggingCLIPLanguageBackbone(BaseModule):

    def __init__(self,
                 model_name: str,
                 dropout: float = 0.0,
                 init_cfg: OptMultiConfig = None) -> None:

        super().__init__(init_cfg=init_cfg)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.config = CLIPTextConfig.from_pretrained(model_name,
                                                     attention_dropout=dropout)
        self.model = CLIPTP.from_pretrained(model_name, config=self.config)
        self.out_dim = self.config.projection_dim
        self._freeze_modules()

    def forward_tokenizer(self, texts):
        if not hasattr(self, 'text'):
            text = list(itertools.chain(*texts))
            text = self.tokenizer(text=text, return_tensors='pt', padding=True)
            self.text = text.to(device=self.model.device)
        return self.text

    def forward(self, text: list[list[str]]) -> Tensor:
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences not equal in batch')
        text = list(itertools.chain(*text))
        text = self.tokenizer(text=text, return_tensors='pt', padding=True)
        text = text.to(device=self.model.device)
        txt_outputs = self.model(**text)
        txt_feats = txt_outputs.text_embeds
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        txt_feats = txt_feats.reshape(-1, num_per_batch[0],
                                      txt_feats.shape[-1])
        return txt_feats

    def _freeze_modules(self):
        self.model.eval()
        for _, module in self.model.named_modules():
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
        return

    def train(self, mode=True):
        super().train(mode)
        self._freeze_modules()


class NormalizedLinear(nn.Linear):

    def forward(self, *args, **kwargs) -> torch.Tensor:
        x = super().forward(*args, **kwargs)
        return F.normalize(x)


@MODELS.register_module()
class OADPClassifier(BaseModule):

    def __init__(
        self,
        *args,
        text_model: OptMultiConfig,
        in_features: int,
        out_features: int,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.text_model = MODELS.build(text_model)
        self._linear = NormalizedLinear(in_features, self.text_model.out_dim)
        self.bg_embedding = nn.Parameter(torch.zeros(1, self.text_model.out_dim))

    @property
    def texts_feat(self) -> torch.Tensor:
        texts_feat = self.text_model(Globals.texts)
        b, n, c = texts_feat.shape
        bg_embedding = F.normalize(repeat(self.bg_embedding, '1 c -> b 1 c', b=b))
        return torch.cat([texts_feat, bg_embedding], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        roi_feat = rearrange(x, '(b n) c -> b n c', n=Globals.sample_num)
        x = self._linear(roi_feat)
        texts_feat_t = rearrange(self.texts_feat, 'b n c -> b c n')
        y = torch.matmul(x, texts_feat_t)
        return rearrange(y, 'b n c -> (b n) c')
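A hedged shape sketch for the new classifier: with B images, N prompts per image and Globals.sample_num ROIs per image, the input is (B * sample_num, in_features) and the output is (B * sample_num, N + 1), where the extra column comes from the learned background embedding. The model name and the sizes below are illustrative assumptions, not values from this commit.

# Hedged shape sketch for OADPClassifier (sizes are illustrative).
import torch
from mmdet.registry import MODELS
from oadp.utils.globals import Globals
import oadp.models  # noqa: F401 -- registers OADPClassifier and the backbone

classifier = MODELS.build(dict(
    type='OADPClassifier',
    text_model=dict(
        type='HuggingCLIPLanguageBackbone',
        model_name='openai/clip-vit-base-patch32',
    ),
    in_features=1024,  # assumed cls_last_dim of the bbox head
    out_features=4,    # unused here: the output width follows the text bank
))

Globals.texts = [['cat', 'dog', 'person']] * 2  # 2 images, 3 prompts each
Globals.sample_num = 512                        # ROIs per image (assumed)

roi_feats = torch.randn(2 * 512, 1024)          # (batch * sample_num, in_features)
logits = classifier(roi_feats)
print(logits.shape)                             # torch.Size([1024, 4]) -- 3 classes + background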
15 changes: 13 additions & 2 deletions oadp/models/detectors.py
@@ -1,10 +1,21 @@
from mmdet.models import TwoStageDetector
from mmdet.registry import MODELS


from ..utils.globals import Globals

@MODELS.register_module()
class OADP(TwoStageDetector):


    def text_bank(self, data_samples):
        text_batch_list = []
        for data_sample in data_samples:
            text_batch_list.append(data_sample.texts)
        return text_batch_list

    def forward(self, inputs, data_samples=None, mode='tensor'):
        Globals.texts = self.text_bank(data_samples=data_samples)
        if mode == 'loss':
            Globals.sample_num = self.train_cfg.rcnn.sampler.num
        else:
            Globals.sample_num = self.test_cfg.rpn.max_per_img
        return super().forward(inputs, data_samples, mode)
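A minimal sketch of the resulting data flow, assuming the `texts` meta key packed by the pipeline in pretrain.py lands in each data sample's metainfo: OADP.forward copies the per-image prompt lists to Globals, and OADPClassifier.texts_feat reads them back. The sample construction below is illustrative only.

# Hedged sketch of how `texts` travels from the pipeline to the classifier.
from mmdet.structures import DetDataSample
from oadp.utils.globals import Globals

sample = DetDataSample(metainfo=dict(texts=['cat', 'dog', 'person']))
assert sample.texts == ['cat', 'dog', 'person']  # metainfo keys read as attributes

# OADP.text_bank collects one prompt list per image ...
Globals.texts = [s.texts for s in [sample, sample]]
# ... and Globals.sample_num is set from train_cfg/test_cfg before the
# ROI head runs, so OADPClassifier can un-flatten its input.
Globals.sample_num = 512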
9 changes: 9 additions & 0 deletions oadp/utils/globals.py
@@ -0,0 +1,9 @@
import torch

class Globals:
    """Entry point for global variables.
    Not to be confused with the global distillation branch.
    """
    sample_num: int
    texts: list[list[str]]
7 changes: 7 additions & 0 deletions requirements.txt
@@ -0,0 +1,7 @@
transformers
albumentations==1.3.1
openmim
mmcv==2.0.0
mmdet
mmyolo
git+https://github.com/lvis-dataset/lvis-api.git
