feat: soco_detpro_bs16
LutingWang committed Sep 27, 2024
1 parent dfafc14 commit 8dbe866
Showing 10 changed files with 85 additions and 34 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -92,7 +92,8 @@ ipython_config.py
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
Pipfile
Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
29 changes: 22 additions & 7 deletions README.md
@@ -25,7 +25,7 @@ bash setup.sh
<https://toddai.readthedocs.io/en/latest/data/lvis.html>

```bash
python -m oadp.build_annotations
python tools/build_annotations.py
```

The following files will be generated
@@ -49,11 +49,15 @@ OADP/data

## Pretrained Models

<https://toddai.readthedocs.io/en/latest/data/wordnet.html>

<https://toddai.readthedocs.io/en/latest/pretrained/clip.html>

<https://toddai.readthedocs.io/en/latest/pretrained/ram.html>

```shell
mkdir pretrained
python -c "import torchvision; _ = torchvision.models.ResNet50_Weights.IMAGENET1K_V1.get_state_dict(True)"
mkdir -p pretrained
python -c "import torchvision; _ = torchvision.models.ResNet50_Weights.IMAGENET1K_V1.get_state_dict()"
ln -s ~/.cache/torch/hub/checkpoints/ pretrained/torch
```

@@ -76,6 +80,17 @@ OADP/data/prompts
└── ml_coco.pth
```

```bash
mkdir -p pretrained/detpro
bypy downfile iou_neg5_ens.pth pretrained/detpro
python -m oadp.prompts.detpro
```

```bash
mkdir -p pretrained/soco
bypy downfile current_mmdetection_Head.pth pretrained/soco/soco_star_mask_rcnn_r50_fpn_400e.pth
```

## OAKE

The following scripts extract features with CLIP, which can be very time-consuming. Therefore, all the scripts support automatically resuming, by skipping existing feature files. However, the existing feature files are sometimes broken. In such cases, users can set the `auto_fix` option to inspect the integrity of each feature file.
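
The resume behaviour described above can be sketched roughly as follows. This is a minimal illustration, not the OADP implementation: the helper names `is_intact`, `should_skip`, and `extract_all` are hypothetical, and features are stored with `pickle` here only to keep the sketch dependency-free (the real scripts may use a different file format).

```python
import pathlib
import pickle


def is_intact(path: pathlib.Path) -> bool:
    """Return True if the feature file can be fully unpickled."""
    try:
        with path.open('rb') as f:
            pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError):
        return False
    return True


def should_skip(path: pathlib.Path, auto_fix: bool) -> bool:
    """Decide whether an output file can be skipped when resuming."""
    if not path.exists():
        return False
    # Without auto_fix, any existing file is trusted.
    # With auto_fix, the file is read back to verify its integrity first.
    return is_intact(path) if auto_fix else True


def extract_all(keys, output_dir, extract, auto_fix=False):
    """Run `extract` for each key, skipping keys whose output already exists.

    Returns the list of keys that were actually (re)computed.
    """
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    done = []
    for key in keys:
        path = output_dir / f'{key}.pkl'
        if should_skip(path, auto_fix):
            continue
        with path.open('wb') as f:
            pickle.dump(extract(key), f)
        done.append(key)
    return done
```

The integrity check costs one full read per existing file, which is presumably why it is opt-in rather than the default.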
@@ -100,9 +115,9 @@ The number of files generated by OAKE-objects may be less than the number of ima
Images without objects are skipped.

```bash
bash tools/torchrun.sh tools/generate_sample_images.py coco
python tools/encode_sample_images.py coco
python tools/sample_visual_category_embeddings.py coco clip
bash tools/torchrun.sh tools/generate_sample_images.py lvis
python tools/encode_sample_images.py lvis
python tools/sample_visual_category_embeddings.py lvis clip
```

## DP
@@ -116,7 +131,7 @@ To conduct training for coco
To conduct training for lvis

```bash
bash tools/torchrun.sh -m oadp.dp.train ov_lvis configs/dp/ov_lvis.py
bash tools/torchrun.sh -m oadp.dp.train ov_lvis configs/dp/ov_lvis.py # --load-model-from pretrained/soco/soco_star_mask_rcnn_r50_fpn_400e.pth
```

To test a specific checkpoint
2 changes: 1 addition & 1 deletion makefile
@@ -6,7 +6,7 @@ latest_todd_version := $(shell curl -H "Accept: application/vnd.github.sha" -s h
define install_todd
pipenv run pip uninstall -y todd_ai
GIT_LFS_SKIP_SMUDGE=1 pipenv run pip install \
git+https://github.com/LutingWang/todd.git@$(1)#egg=todd_ai\[optional\]
git+https://github.com/LutingWang/todd.git@$(1)#egg=todd_ai\[optional,dev,lint,doc,test\]
pipenv run pip uninstall -y opencv-python opencv-python-headless
pipenv run pip install opencv-python-headless
endef
3 changes: 2 additions & 1 deletion oadp/dp/train.py
@@ -47,7 +47,8 @@ def main():
config,
name=args.name,
load_from=args.load_from,
load_model_from=args.load_model_from[0],
load_model_from=args.load_model_from[0]
if args.load_model_from else None,
auto_resume=args.auto_resume,
autocast=args.autocast,
)
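
The guard added in this hunk matters when `--load-model-from` is omitted. The sketch below reproduces the situation under a stated assumption: the option is declared with `nargs='+'` and no default, so `argparse` leaves it as `None`. The actual parser definition is not part of this diff, and `first_or_none` is a hypothetical helper that only mirrors the guarded expression.

```python
import argparse
from typing import List, Optional, Sequence


def parse_args(argv: Sequence[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Assumption: one or more paths, no default, so the attribute is None
    # whenever the flag is not given on the command line.
    parser.add_argument('--load-model-from', nargs='+')
    return parser.parse_args(argv)


def first_or_none(values: Optional[List[str]]) -> Optional[str]:
    """Mirror `args.load_model_from[0] if args.load_model_from else None`."""
    return values[0] if values else None


if __name__ == '__main__':
    args = parse_args([])
    # Indexing None directly would raise TypeError; the guard avoids that.
    print(first_or_none(args.load_model_from))
```

Under that assumption, the unguarded `args.load_model_from[0]` fails with a `TypeError` for any run that does not load a pretrained model.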
22 changes: 22 additions & 0 deletions oadp/prompts/detpro.py
@@ -0,0 +1,22 @@
import torch
from mmdet.datasets import LVISV1Dataset


def main() -> None:
    embeddings = torch.load('pretrained/detpro/iou_neg5_ens.pth', 'cpu')

    # lvis annotations have a typo, which is fixed in mmdet
    # we need to change it back, so that the names match
    names: list[str] = list(LVISV1Dataset.METAINFO['classes'])
    i = names.index('speaker_(stereo_equipment)')
    names[i] = 'speaker_(stero_equipment)'

    state_dict = dict(
        embeddings=embeddings,
        names=names,
    )
    torch.save(state_dict, 'data/prompts/detpro_lvis.pth')


if __name__ == '__main__':
    main()
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -50,6 +50,8 @@ plugins = 'numpy.typing.mypy_plugin'

[[tool.mypy.overrides]]
module = [
    'lvis.*',
    'mmdet.*',
    'nltk.*',
]
ignore_missing_imports = true
10 changes: 4 additions & 6 deletions setup.sh
@@ -1,16 +1,14 @@
set -e

curl https://raw.githubusercontent.com/LutingWang/todd/main/bin/pipenv_install | bash -s -- 3.11.3
curl https://raw.githubusercontent.com/LutingWang/todd/main/bin/pipenv_install | bash -s -- 3.11.10

git config --global --add safe.directory $(dirname $(realpath $0))
pipenv run pip install /data/wlt/wheels/torch-2.4.1+cu121-cp311-cp311-linux_x86_64.whl
pipenv run pip install -i https://download.pytorch.org/whl/cu121 torchvision==0.19.1+cu121

# pipenv run pip install -i https://download.pytorch.org/whl/cu118 torch==2.4.0+cu118 torchvision==0.19.0+cu118
pipenv run pip install /mnt/bn/wangluting/wheels/torch-2.4.0+cu118-cp311-cp311-linux_x86_64.whl
pipenv run pip install -i https://download.pytorch.org/whl/cu118 torchvision==0.19.0+cu118
pipenv run pip install \
    nni \
    openmim \
    scikit-learn \
    scikit-learn

pipenv run mim install mmcv
pipenv run mim install mmdet --no-deps # mmdet requires mmcv<2.2.0
30 changes: 14 additions & 16 deletions tools/build_annotations.py
@@ -1,21 +1,23 @@
import json
import pathlib
import sys
from abc import ABC, abstractmethod
from typing import Any

import todd
from lvis import LVIS
from mmdet.datasets.api_wrappers import COCO

from oadp.models import Categories, coco, lvis
sys.path.insert(0, '')
from oadp.categories import Categories, coco, lvis # noqa: E402 E501 isort:skip pylint: disable=wrong-import-position
from mmdet.datasets.api_wrappers import COCO # noqa: E402 E501 isort:skip pylint: disable=wrong-import-position,wrong-import-order

Data = dict[str, Any]


class Builder(ABC):
    CATEGORIES: Categories

    def __init__(self, categories: Categories, root: str) -> None:
        self._categories = categories
    def __init__(self, root: str) -> None:
        self._root = pathlib.Path(root)

    @abstractmethod
@abstractmethod
@@ -38,7 +40,7 @@ def _dump(self, data: Data, file: pathlib.Path, suffix: str) -> None:
    def _filter_annotations(self, data: Data) -> Data:
        annotations = [
            annotation for annotation in data['annotations']
            if annotation['category_id'] < self._categories.num_bases
            if annotation['category_id'] < self.CATEGORIES.num_bases
        ]
        return data | dict(annotations=annotations)

@@ -57,7 +59,7 @@ def build(self, filename: str, min_: bool) -> None:
        data = self._load(file)

        category_oid2nid = {  # nid = new id, oid = old id
            category['id']: self._categories.all_.index(category['name'])
            category['id']: self.CATEGORIES.all_.index(category['name'])
            for category in data['categories']
        }
        self._map_category_ids(data, category_oid2nid)
@@ -66,22 +68,20 @@ def build(self, filename: str, min_: bool) -> None:
            key=lambda category: category['id'],
        )

        self._dump(data, file, str(self._categories.num_all))
        self._dump(data, file, str(self.CATEGORIES.num_all))
        filtered_data = self._filter_annotations(data)
        self._dump(filtered_data, file, str(self._categories.num_bases))
        self._dump(filtered_data, file, str(self.CATEGORIES.num_bases))
        if min_:
            filtered_data = self._filter_images(data)
            self._dump(filtered_data, file, f'{self._categories.num_all}.min')
            self._dump(filtered_data, file, f'{self.CATEGORIES.num_all}.min')


class COCOBuilder(Builder):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(coco, *args, **kwargs)
    CATEGORIES = coco

    def _load(self, file: pathlib.Path) -> Data:
        data = COCO(file)
        category_ids = data.get_cat_ids(cat_names=self._categories.all_)
        category_ids = data.get_cat_ids(cat_names=self.CATEGORIES.all_)
        annotation_ids = data.get_ann_ids(cat_ids=category_ids)
        image_ids = data.get_img_ids()
        categories = data.load_cats(category_ids)
@@ -95,9 +95,7 @@ def _load(self, file: pathlib.Path) -> Data:


class LVISBuilder(Builder):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(lvis, *args, **kwargs)
    CATEGORIES = lvis

    def _load(self, file: pathlib.Path) -> Data:
        data = LVIS(file)
7 changes: 5 additions & 2 deletions tools/sample_visual_category_embeddings.py
@@ -3,7 +3,7 @@
import pathlib
import random
from collections import defaultdict
from typing import TypedDict, cast
from typing import TypedDict

import todd
import todd.tasks.object_detection as od
@@ -81,7 +81,7 @@ def oake(args: argparse.Namespace) -> dict[str, torch.Tensor]:
batch['categories'].shape[0]
)
for tensor, category in zip(batch['tensors'], batch['categories']):
embeddings[category.item()](tensor)
embeddings[category.item()](tensor.clone())

categories = torch.load(data_root / 'categories.pth', 'cpu')
return {
@@ -94,6 +94,9 @@ def oake(args: argparse.Namespace) -> dict[str, torch.Tensor]:
def main() -> None:
    args = parse_args()

    # some coco categories are not in the sample images
    assert args.dataset != 'coco'

    oake_embeddings = oake(args)
    sample_image_embeddings: dict[str, torch.Tensor] = torch.load(
        f'work_dirs/sample_image_embeddings/{args.dataset}.pth',
11 changes: 11 additions & 0 deletions tools/torchrun.sh
@@ -0,0 +1,11 @@
set -e

for port in {5000..5100}; do
    if ! sudo netstat -tunlp | grep -w :${port} > /dev/null; then
        break
    fi
done

set -x

torchrun --nproc-per-node=$(nvidia-smi -L | wc -l) --master-port=${port} "$@"
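
For comparison, the port scan in `tools/torchrun.sh` could be approximated in Python without `sudo` or `netstat` by simply trying to bind each candidate port. This is a sketch under assumptions, not part of the commit: `find_free_port` is a hypothetical helper, and it checks only TCP on the loopback interface. Unlike the shell loop, which falls through with the last port if every candidate is taken, this version raises an error in that case. Either approach leaves a window between the check and the moment `torchrun` actually binds the port.

```python
import socket


def find_free_port(start: int = 5000, stop: int = 5100,
                   host: str = '127.0.0.1') -> int:
    """Return the first port in [start, stop] that can be bound on `host`."""
    for port in range(start, stop + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((host, port))
            except OSError:
                # The port is in use (or not bindable); try the next one.
                continue
            return port
    raise RuntimeError(f'no free port in range {start}..{stop}')


if __name__ == '__main__':
    print(find_free_port())
```

The result could then be passed to `torchrun` through its `--master-port` option, as the shell script does with `${port}`.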
