Skip to content

Commit

Permalink
feat: detpro soco
Browse files Browse the repository at this point in the history
  • Loading branch information
LutingWang committed Sep 18, 2024
1 parent c3dd597 commit 90c9046
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 13 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ To conduct training for coco
To conduct training for lvis

```bash
[DRY_RUN=True] [TRAIN_WITH_VAL_DATASET=True] bash tools/torchrun.sh -m oadp.dp.train ov_lvis configs/dp/ov_lvis.py
bash tools/torchrun.sh -m oadp.dp.train ov_lvis configs/dp/ov_lvis.py
```

To test a specific checkpoint

```bash
[DRY_RUN=True] bash tools/torchrun.sh -m oadp.dp.test configs/dp/oadp_ov_coco.py work_dirs/oadp_ov_coco/iter_32000.pth
[DRY_RUN=True] bash tools/torchrun.sh -m oadp.dp.test configs/dp/oadp_ov_lvis.py work_dirs/oadp_ov_lvis/epoch_24.pth
bash tools/torchrun.sh -m oadp.dp.test ov_coco configs/dp/ov_coco.py --load-model-from work_dirs/ov_coco/epoch_24.pth --visual xxx
bash tools/torchrun.sh -m oadp.dp.test ov_lvis configs/dp/ov_lvis.py --load-model-from work_dirs/ov_lvis/epoch_24.pth --visual xxx
```

For the instance segmentation performance on LVIS, use the `metrics` argument
Expand Down
2 changes: 1 addition & 1 deletion configs/dp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
dist_cfg=dict(timeout=1800),
)

vis_backends = [dict(type='LocalVisBackend')]
Expand Down
14 changes: 11 additions & 3 deletions configs/dp/ov_lvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@
]

cls_predictor_cfg = dict(
type='OVClassifier',
scaler=dict(train=0.01, val=0.007),
# type='OVClassifier',
# scaler=dict(train=0.01, val=0.007),
type='ViLDClassifier',
prompts='data/prompts/detpro_lvis.pth',
scaler=dict(
train=0.01,
val=0.007,
),
)
model = dict(
global_head=dict(
classifier=dict(
**cls_predictor_cfg,
# **cls_predictor_cfg,
type='ViLDClassifier',
prompts='data/prompts/detpro_lvis.pth',
out_features=1203,
),
),
Expand Down
14 changes: 9 additions & 5 deletions oadp/categories/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
from torch import nn
import torch.nn.functional as F

from ...utils import Globals

Expand Down Expand Up @@ -50,8 +51,11 @@ def get_embeddings(self) -> list[torch.Tensor]:
def forward(self) -> torch.Tensor:
embeddings = self.get_embeddings()

embeddings = [
embedding[random.randrange(embedding.shape[0])]
for embedding in embeddings
]
return torch.stack(embeddings)
# embeddings = [
# embedding[random.randrange(embedding.shape[0])]
# for embedding in embeddings
# ]
# return torch.stack(embeddings)

embeddings = [embedding.mean(0) for embedding in embeddings]
return F.normalize(torch.stack(embeddings))
7 changes: 7 additions & 0 deletions oadp/dp/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,17 @@ def main():
Globals.categories = Categories.get(config.categories)

# trainer = DPRunnerRegistry.build(
# config,
# name=args.name,
# load_from=args.load_from,
# auto_resume=args.auto_resume,
# autocast=args.autocast,
# )
trainer = DPRunner.from_cfg(
config,
name=args.name,
load_from=args.load_from,
load_model_from=args.load_model_from[0],
auto_resume=args.auto_resume,
autocast=args.autocast,
)
Expand Down
4 changes: 3 additions & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ curl https://raw.githubusercontent.com/LutingWang/todd/main/bin/pipenv_install |

git config --global --add safe.directory $(dirname $(realpath $0))

pipenv run pip install -i https://download.pytorch.org/whl/cu118 torch==2.4.0+cu118 torchvision==0.19.0+cu118
# pipenv run pip install -i https://download.pytorch.org/whl/cu118 torch==2.4.0+cu118 torchvision==0.19.0+cu118
pipenv run pip install /mnt/bn/wangluting/wheels/torch-2.4.0+cu118-cp311-cp311-linux_x86_64.whl
pipenv run pip install -i https://download.pytorch.org/whl/cu118 torchvision==0.19.0+cu118
pipenv run pip install \
nni \
openmim \
Expand Down

0 comments on commit 90c9046

Please sign in to comment.