From 90c9046b2d8bed7daef44b2392226e099badb5c3 Mon Sep 17 00:00:00 2001 From: "wangluting.wlt" Date: Wed, 18 Sep 2024 22:15:03 +0800 Subject: [PATCH] feat: detpro soco --- README.md | 6 +++--- configs/dp/base.py | 2 +- configs/dp/ov_lvis.py | 14 +++++++++++--- oadp/categories/embeddings/base.py | 14 +++++++++----- oadp/dp/train.py | 7 +++++++ setup.sh | 4 +++- 6 files changed, 34 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index a8e51b3..cb28678 100644 --- a/README.md +++ b/README.md @@ -114,14 +114,14 @@ To conduct training for coco To conduct training for lvis ```bash -[DRY_RUN=True] [TRAIN_WITH_VAL_DATASET=True] bash tools/torchrun.sh -m oadp.dp.train ov_lvis configs/dp/ov_lvis.py +bash tools/torchrun.sh -m oadp.dp.train ov_lvis configs/dp/ov_lvis.py ``` To test a specific checkpoint ```bash -[DRY_RUN=True] bash tools/torchrun.sh -m oadp.dp.test configs/dp/oadp_ov_coco.py work_dirs/oadp_ov_coco/iter_32000.pth -[DRY_RUN=True] bash tools/torchrun.sh -m oadp.dp.test configs/dp/oadp_ov_lvis.py work_dirs/oadp_ov_lvis/epoch_24.pth +bash tools/torchrun.sh -m oadp.dp.test ov_coco configs/dp/ov_coco.py --load-model-from work_dirs/ov_coco/epoch_24.pth --visual xxx +bash tools/torchrun.sh -m oadp.dp.test ov_lvis configs/dp/ov_lvis.py --load-model-from work_dirs/ov_lvis/epoch_24.pth --visual xxx ``` For the instance segmentation performance on LVIS, use the `metrics` argument diff --git a/configs/dp/base.py b/configs/dp/base.py index 94ef1d4..d2513cb 100644 --- a/configs/dp/base.py +++ b/configs/dp/base.py @@ -11,7 +11,7 @@ env_cfg = dict( cudnn_benchmark=False, mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), - dist_cfg=dict(backend='nccl'), + dist_cfg=dict(timeout=1800), ) vis_backends = [dict(type='LocalVisBackend')] diff --git a/configs/dp/ov_lvis.py b/configs/dp/ov_lvis.py index 656c523..0d39007 100644 --- a/configs/dp/ov_lvis.py +++ b/configs/dp/ov_lvis.py @@ -7,13 +7,21 @@ ] cls_predictor_cfg = dict( - type='OVClassifier', - scaler=dict(train=0.01, val=0.007), + # type='OVClassifier', + # scaler=dict(train=0.01, val=0.007), + type='ViLDClassifier', + prompts='data/prompts/detpro_lvis.pth', + scaler=dict( + train=0.01, + val=0.007, + ), ) model = dict( global_head=dict( classifier=dict( - **cls_predictor_cfg, + # **cls_predictor_cfg, + type='ViLDClassifier', + prompts='data/prompts/detpro_lvis.pth', out_features=1203, ), ), diff --git a/oadp/categories/embeddings/base.py b/oadp/categories/embeddings/base.py index 3f3f16a..6af7156 100644 --- a/oadp/categories/embeddings/base.py +++ b/oadp/categories/embeddings/base.py @@ -6,6 +6,7 @@ import torch from torch import nn +import torch.nn.functional as F from ...utils import Globals @@ -50,8 +51,11 @@ def get_embeddings(self) -> list[torch.Tensor]: def forward(self) -> torch.Tensor: embeddings = self.get_embeddings() - embeddings = [ - embedding[random.randrange(embedding.shape[0])] - for embedding in embeddings - ] - return torch.stack(embeddings) + # embeddings = [ + # embedding[random.randrange(embedding.shape[0])] + # for embedding in embeddings + # ] + # return torch.stack(embeddings) + + embeddings = [embedding.mean(0) for embedding in embeddings] + return F.normalize(torch.stack(embeddings)) diff --git a/oadp/dp/train.py b/oadp/dp/train.py index b1ebf4e..9015a43 100644 --- a/oadp/dp/train.py +++ b/oadp/dp/train.py @@ -37,10 +37,17 @@ def main(): Globals.categories = Categories.get(config.categories) # trainer = DPRunnerRegistry.build( + # config, + # name=args.name, + # load_from=args.load_from, + # auto_resume=args.auto_resume, + # autocast=args.autocast, + # ) trainer = DPRunner.from_cfg( config, name=args.name, load_from=args.load_from, + load_model_from=args.load_model_from[0], auto_resume=args.auto_resume, autocast=args.autocast, ) diff --git a/setup.sh b/setup.sh index 902f7d3..83fcf03 100644 --- a/setup.sh +++ b/setup.sh @@ -4,7 +4,9 @@ curl https://raw.githubusercontent.com/LutingWang/todd/main/bin/pipenv_install | git config --global --add safe.directory $(dirname $(realpath $0)) -pipenv run pip install -i https://download.pytorch.org/whl/cu118 torch==2.4.0+cu118 torchvision==0.19.0+cu118 +# pipenv run pip install -i https://download.pytorch.org/whl/cu118 torch==2.4.0+cu118 torchvision==0.19.0+cu118 +pipenv run pip install /mnt/bn/wangluting/wheels/torch-2.4.0+cu118-cp311-cp311-linux_x86_64.whl +pipenv run pip install -i https://download.pytorch.org/whl/cu118 torchvision==0.19.0+cu118 pipenv run pip install \ nni \ openmim \