
Commit

feat: build_prompts
LutingWang committed Oct 16, 2024
1 parent 65685d6 commit 9796e86
Showing 1 changed file with 68 additions and 37 deletions.
105 changes: 68 additions & 37 deletions tools/build_prompts.py
@@ -1,33 +1,45 @@
import clip
import clip.model
import einops
import torch
import torch.nn.functional as F
import tqdm
from mmdet.datasets import LVISV1Dataset
# from mmdet.datasets import LVISV1Dataset
import todd.tasks.natural_language_processing as nlp
from todd.models.modules import CLIPText

from oadp.categories import coco, lvis, objects365


def vild() -> None:
prompts = [
"This is a {}", "There is a {}", "a photo of a {} in the scene",
"This is a {}",
"There is a {}",
"a photo of a {} in the scene",
"a photo of a small {} in the scene",
"a photo of a medium {} in the scene",
"a photo of a large {} in the scene", "a photo of a {}",
"a photo of a small {}", "a photo of a medium {}",
"a photo of a large {}", "This is a photo of a {}",
"This is a photo of a small {}", "This is a photo of a medium {}",
"This is a photo of a large {}", "There is a {} in the scene",
"There is the {} in the scene", "There is one {} in the scene",
"This is a {} in the scene", "This is the {} in the scene",
"This is one {} in the scene", "This is one small {} in the scene",
"a photo of a large {} in the scene",
"a photo of a {}",
"a photo of a small {}",
"a photo of a medium {}",
"a photo of a large {}",
"This is a photo of a {}",
"This is a photo of a small {}",
"This is a photo of a medium {}",
"This is a photo of a large {}",
"There is a {} in the scene",
"There is the {} in the scene",
"There is one {} in the scene",
"This is a {} in the scene",
"This is the {} in the scene",
"This is one {} in the scene",
"This is one small {} in the scene",
"This is one medium {} in the scene",
"This is one large {} in the scene",
"There is a small {} in the scene",
"There is a medium {} in the scene",
"There is a large {} in the scene", "There is a {} in the photo",
"There is the {} in the photo", "There is one {} in the photo",
"There is a large {} in the scene",
"There is a {} in the photo",
"There is the {} in the photo",
"There is one {} in the photo",
"There is a small {} in the photo",
"There is the small {} in the photo",
"There is one small {} in the photo",
@@ -36,8 +48,10 @@ def vild() -> None:
"There is one medium {} in the photo",
"There is a large {} in the photo",
"There is the large {} in the photo",
"There is one large {} in the photo", "There is a {} in the picture",
"There is the {} in the picture", "There is one {} in the picture",
"There is one large {} in the photo",
"There is a {} in the picture",
"There is the {} in the picture",
"There is one {} in the picture",
"There is a small {} in the picture",
"There is the small {} in the picture",
"There is one small {} in the picture",
@@ -46,16 +60,22 @@ def vild() -> None:
"There is one medium {} in the picture",
"There is a large {} in the picture",
"There is the large {} in the picture",
"There is one large {} in the picture", "This is a {} in the photo",
"This is the {} in the photo", "This is one {} in the photo",
"This is a small {} in the photo", "This is the small {} in the photo",
"There is one large {} in the picture",
"This is a {} in the photo",
"This is the {} in the photo",
"This is one {} in the photo",
"This is a small {} in the photo",
"This is the small {} in the photo",
"This is one small {} in the photo",
"This is a medium {} in the photo",
"This is the medium {} in the photo",
"This is one medium {} in the photo",
"This is a large {} in the photo", "This is the large {} in the photo",
"This is one large {} in the photo", "This is a {} in the picture",
"This is the {} in the picture", "This is one {} in the picture",
"This is a large {} in the photo",
"This is the large {} in the photo",
"This is one large {} in the photo",
"This is a {} in the picture",
"This is the {} in the picture",
"This is one {} in the picture",
"This is a small {} in the picture",
"This is the small {} in the picture",
"This is one small {} in the picture",
@@ -64,19 +84,30 @@ def vild() -> None:
"This is one medium {} in the picture",
"This is a large {} in the picture",
"This is the large {} in the picture",
"This is one large {} in the picture"
"This is one large {} in the picture",
]

tokenizer = nlp.tokenizers.CLIPTokenizer(
bpe_path='pretrained/clip/clip_bpe.txt.gz',
)

model = CLIPText(out_features=512)
model.load_pretrained('pretrained/clip/ViT-B-32.pt')
model.requires_grad_(False)
model.eval()
model.cuda()

names = sorted(set(coco.all_ + lvis.all_ + objects365.all_))
model, _ = clip.load_default()

embeddings = []
with torch.no_grad():
for prompt in tqdm.tqdm(prompts):
texts = map(prompt.format, names)
tokens = clip.adaptively_tokenize(texts)
embedding = model.encode_text(tokens)
embeddings.append(embedding)
tokens = tokenizer.encodes(texts)
tokens = tokens.cuda()
x = model(tokens)
eos = CLIPText.eos(tokens, x)
embeddings.append(eos)
embeddings_ = torch.stack(embeddings)
embeddings_ = F.normalize(embeddings_, dim=-1)
embeddings_ = einops.reduce(embeddings_, 'n ... -> ...', 'mean')
@@ -85,22 +116,22 @@ def vild() -> None:
torch.save(state_dict, 'data/prompts/vild.pth')


def detpro() -> None:
embeddings = torch.load('pretrained/detpro/iou_neg5_ens.pth', 'cpu')
# def detpro() -> None:
# embeddings = torch.load('pretrained/detpro/iou_neg5_ens.pth', 'cpu')

# lvis annotations have a typo, which is fixed in mmdet
# we need to change it back, so that the names match
names: list[str] = list(LVISV1Dataset.METAINFO['classes'])
i = names.index('speaker_(stereo_equipment)')
names[i] = 'speaker_(stero_equipment)'
# # lvis annotations have a typo, which is fixed in mmdet
# # we need to change it back, so that the names match
# names: list[str] = list(LVISV1Dataset.METAINFO['classes'])
# i = names.index('speaker_(stereo_equipment)')
# names[i] = 'speaker_(stero_equipment)'

state_dict = dict(embeddings=embeddings, names=names)
torch.save(state_dict, 'data/prompts/detpro_lvis.pth')
# state_dict = dict(embeddings=embeddings, names=names)
# torch.save(state_dict, 'data/prompts/detpro_lvis.pth')


def main() -> None:
vild()
detpro()
# detpro()


if __name__ == '__main__':
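For context on the new encoding path: the loop now runs the tokenized prompts through CLIPText and pools the transformer output with CLIPText.eos(tokens, x). This presumably mirrors OpenAI CLIP's text head, which takes the feature at the end-of-sequence position as the prompt embedding. A minimal conceptual sketch of that pooling step, assuming CLIP-style tokenization where the EOS token has the largest id in each sequence (the helper name eos_pool is illustrative and not part of todd):

import torch


def eos_pool(tokens: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # tokens: (batch, seq_len) token ids; in CLIP's BPE vocabulary the EOS
    # token has the largest id, so argmax recovers its position per prompt.
    # x: (batch, seq_len, dim) transformer outputs.
    eos_indices = tokens.argmax(dim=-1)
    return x[torch.arange(x.shape[0]), eos_indices]  # (batch, dim)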

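The end product is data/prompts/vild.pth, which, like the detpro branch above, appears to store the prompt-ensembled, L2-normalized text embeddings alongside the category names. A minimal sketch of how such a file could be consumed for open-vocabulary classification; the embeddings/names keys are taken from the detpro code above, while the temperature and the random stand-in for image features are assumptions rather than part of this commit:

import torch
import torch.nn.functional as F

state = torch.load('data/prompts/vild.pth', map_location='cpu')
text_embeddings = F.normalize(state['embeddings'].float(), dim=-1)  # (num_classes, 512)
names = state['names']

# Image features would come from the matching CLIP image encoder (ViT-B/32 here);
# random vectors keep the sketch self-contained.
image_embeddings = F.normalize(torch.randn(4, 512), dim=-1)  # (num_regions, 512)

logits = 100.0 * image_embeddings @ text_embeddings.T  # CLIP-style temperature, assumed
predictions = [names[i] for i in logits.argmax(dim=-1).tolist()]
print(predictions)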