Skip to content

Commit

Permalink
feat: dp train resume
Browse files (browse the repository at this point in the history)
  • Loading branch information
LutingWang committed Sep 10, 2024
1 parent 7da13c4 commit c3dd597
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ wget https://raw.githubusercontent.com/OPPOMKLab/recognize-anything/main/dataset

```bash
bash tools/torchrun.sh -m oadp.prompts.val lvis_clip --config type::LVISPrompter --model type::CLIP
bash tools/torchrun.sh -m oadp.prompts.val lvis_t5 --config type::LVISPrompter --model type::T5
```

Download `ml_coco.pth` from [Baidu Netdisk][].
Expand Down
4 changes: 0 additions & 4 deletions configs/dp/datasets/objects365.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
_base_ = [
'coco_detection.py',
]

categories = 'objects365'
dataset_type = 'Objects365V2Dataset'
data_root = 'data/objects365v2/'
Expand Down
4 changes: 2 additions & 2 deletions configs/dp/schedules/2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=2.5e-5),
optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.),
)

# Default setting for scaling LR automatically
# - `enable` controls whether the LR is scaled automatically
# by default.
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
auto_scale_lr = dict(enable=True, base_batch_size=16)

# checkpoint
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
Expand Down
5 changes: 2 additions & 3 deletions oadp/dp/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ def main():
autocast=args.autocast,
)
# log(trainer, args, config)
if args.load_model_from:
# trainer.strategy.load_model_from(args.load_model_from, strict=False)
raise ValueError("load_model_from is not supported")
# if args.load_model_from:
# trainer.strategy.load_model_from(args.load_model_from, strict=False)
# trainer.run()
trainer.train()

Expand Down
23 changes: 17 additions & 6 deletions oadp/prompts/models/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,47 @@

from typing import Iterable

import todd
import torch
from transformers import T5EncoderModel, T5Tokenizer
import torch.nn.functional as F
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
from transformers.modeling_outputs import BaseModelOutput
from todd.registries import InitWeightsMixin

from ..registries import PromptModelRegistry
from .base import BaseModel


@PromptModelRegistry.register_()
class T5(BaseModel):
class T5(InitWeightsMixin, BaseModel):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
tokenizer = T5Tokenizer.from_pretrained('pretrained/t5/t5-large')
tokenizer: T5Tokenizer = AutoTokenizer.from_pretrained(
'pretrained/t5/t5-large',
)
model = T5EncoderModel.from_pretrained('pretrained/t5/t5-large')
model = model.requires_grad_(False)
model = model.eval()
self._tokenizer = tokenizer
self._model = model

def init_weights(self, config: todd.Config) -> bool:
super().init_weights(config)
return False

def forward(
self,
texts: Iterable[str],
batch_size: str | None = None,
) -> torch.Tensor:
assert batch_size is None
embeddings: list[torch.Tensor] = []
for text in texts:
tokens = self._tokenizer(text, return_tensors='pt')
if todd.Store.cuda:
tokens = tokens.to('cuda')
outputs: BaseModelOutput = self._model(**tokens)
embedding = outputs.last_hidden_state.mean(0)
embedding = outputs.last_hidden_state.mean(1)
embeddings.append(embedding)
return torch.stack(embeddings)
embedding = torch.cat(embeddings)
return F.normalize(embedding)
3 changes: 2 additions & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ pipenv run pip install \
scikit-learn \

pipenv run mim install mmcv
pipenv run mim install mmdet --no-deps
pipenv run mim install mmdet --no-deps # mmdet requires mmcv<2.2.0
pipenv run pip install shapely terminaltables # mmdet dependencies

pipenv run pip install \
git+https://github.com/lvis-dataset/lvis-api.git@lvis_challenge_2021
Expand Down

0 comments on commit c3dd597

Please sign in to comment.