This is the official PyTorch code for the paper:
Learning Attention as Disentangler for Compositional Zero-shot Learning
Shaozhe Hao, Kai Han, Kwan-Yee K. Wong
CVPR 2023
Project Page | Code | Paper
TL;DR: A simple cross-attention mechanism efficiently disentangles visual concepts, i.e., attribute and object concepts, which improves CZSL performance.
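For intuition about the mechanism named in the TL;DR, the sketch below implements plain scaled dot-product cross-attention in pure Python. Two concept tokens (a hypothetical attribute query and a hypothetical object query) attend over the same set of image patch tokens, so each concept can pool features from different patches. This is a generic illustration under assumed toy settings (2-D features, no learned projections, a single head), not the ADE implementation.

```python
import math


def softmax(scores):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries, keys, values):
    """Scaled dot-product cross-attention: every query attends over all key/value tokens."""
    scale = math.sqrt(len(keys[0]))
    outputs, weights = [], []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
        w = softmax(scores)
        # Each output is the attention-weighted average of the value vectors.
        out = [sum(wj * v[c] for wj, v in zip(w, values)) for c in range(len(values[0]))]
        outputs.append(out)
        weights.append(w)
    return outputs, weights


# Hypothetical toy tokens: three 2-D "patch" features and two concept queries.
patches = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
attribute_query = [4.0, 0.0]
object_query = [0.0, 4.0]
outs, attn = cross_attention([attribute_query, object_query], keys=patches, values=patches)
for name, w in zip(["attribute", "object"], attn):
    print(name, [round(x, 2) for x in w])
```

In this toy setup the two queries put most of their attention on different patches, which is the kind of separation the disentangling idea relies on.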
Our work is implemented in PyTorch and tested on Ubuntu 18.04/20.04.
- Python 3.8
- PyTorch 1.11.0
Create a conda environment ade using
conda env create -f environment.yml
conda activate ade
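After activating the environment, a quick sanity check can confirm that the interpreter and PyTorch match the tested versions listed above (Python 3.8, PyTorch 1.11.0). The sketch below is an illustrative helper, not part of the repository; it only reads installed package metadata via importlib.metadata and does not import torch.

```python
import sys
from importlib import metadata

# Tested versions taken from the requirements list above.
TESTED_PYTHON = (3, 8)
TESTED_TORCH = "1.11.0"


def check_environment():
    """Report whether the active interpreter and PyTorch match the tested versions."""
    python_version = tuple(sys.version_info[:2])
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        torch_version = None  # PyTorch is not installed in this environment
    # Wheels may carry a local tag such as "+cu113"; compare only the release part.
    torch_release = torch_version.split("+")[0] if torch_version else None
    return {
        "python": "%d.%d" % python_version,
        "python_matches": python_version == TESTED_PYTHON,
        "torch": torch_version,
        "torch_matches": torch_release == TESTED_TORCH,
    }


print(check_environment())
```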
Datasets: You need to download a dataset before training a model on it. Please download Clothing16K and Vaw-CZSL separately. The other datasets can be downloaded with the included script using
bash utils/download_data.sh
In our paper, we conduct experiments on Clothing16K, UT-Zappos50K, and C-GQA; the supplementary material additionally reports experiments on Vaw-CZSL.
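To check that a download is complete, a small script can count the images in each composition folder. The sketch below assumes the layout suggested by the example image path clothing/images/green_suit/000002.jpg used later in this README, i.e. root/images/attribute_object/id.jpg, and it hypothetically splits each folder name into attribute and object at the first underscore. Other datasets may use a different layout or naming rule, so treat the helper as illustrative only.

```python
from pathlib import Path


def count_images_per_composition(dataset_root, suffix=".jpg"):
    """Count images in each composition folder under <dataset_root>/images/.

    Assumption (hypothetical): folders are named <attribute>_<object>, and the
    name splits into attribute and object at its first underscore.
    """
    counts = {}
    images_dir = Path(dataset_root) / "images"
    for folder in sorted(p for p in images_dir.iterdir() if p.is_dir()):
        attribute, _, obj = folder.name.partition("_")
        counts[(attribute, obj)] = sum(1 for f in folder.iterdir() if f.suffix == suffix)
    return counts
```

For example, count_images_per_composition("clothing") would return a dictionary such as {("green", "suit"): ..., ...}, where an unexpectedly empty or missing composition points to an incomplete download.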
Pretrained models: The DINO-pretrained ViT-B/16 backbone can be found here. We also provide ADE models trained on each dataset under the closed-world and open-world settings. To get started quickly, download a pretrained model and evaluate it using
python test.py --log ckpt/MODEL_FOLDER
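When several pretrained models have been downloaded (for example, one per dataset and setting), they can be evaluated in one pass. The sketch below assumes each model sits in its own subfolder of ckpt/ and simply builds the same python test.py --log ckpt/MODEL_FOLDER command shown above for each one; the function name and the run flag are hypothetical, not part of the repository.

```python
import subprocess
from pathlib import Path


def evaluate_checkpoints(ckpt_root="ckpt", run=False):
    """Collect `python test.py --log <folder>` for every model folder under ckpt_root.

    Assumes each pretrained model is stored in its own subfolder. With run=True,
    the commands are executed one after another (from the repository root).
    """
    commands = [
        ["python", "test.py", "--log", str(folder)]
        for folder in sorted(Path(ckpt_root).iterdir())
        if folder.is_dir()
    ]
    if run:
        for cmd in commands:
            subprocess.run(cmd, check=True)  # stop at the first failing evaluation
    return commands
```

Calling evaluate_checkpoints() first lists the commands without running anything, which makes it easy to confirm the folder names before launching evaluate_checkpoints(run=True).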
Train an ADE model with a config file CONFIG_FILE (e.g., configs/clothing.yml) using
python train.py --config CONFIG_FILE
After training, a logs folder is created in which the logs, configs, and checkpoints are saved.
Test an ADE model with a log folder LOG_FOLDER (e.g., logs/ade_cw/clothing) using
python test.py --log LOG_FOLDER
Perform image-to-text retrieval with a log folder LOG_FOLDER (e.g., logs/ade_ow/cgqa) and a number of samples SAMPLE_NUM (default: 100) using
python img2txt_retrieval.py --log LOG_FOLDER --sample_num SAMPLE_NUM
Perform text-to-image retrieval with a log folder LOG_FOLDER (e.g., logs/ade_ow/cgqa) and a text prompt TEXT (e.g., squatting catcher) using
python txt2img_retrieval.py --log LOG_FOLDER --text_prompt TEXT
Perform visual concept retrieval with a log folder LOG_FOLDER (e.g., logs/ade_cw/clothing) and an image path IMG_PATH (e.g., clothing/images/green_suit/000002.jpg) using
python visual_concept_retrieval.py --log LOG_FOLDER --image_path IMG_PATH
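The three retrieval commands above differ only in the script name and one task-specific flag, so they can be driven from a single helper. The sketch below uses exactly the scripts and flags listed above; the helper names are hypothetical. Passing an argument list to subprocess.run bypasses the shell, so a multi-word prompt such as squatting catcher stays a single argument; this assumes txt2img_retrieval.py reads --text_prompt as one string, which this README does not state.

```python
import shlex
import subprocess

# Script name and task-specific flag for each retrieval mode listed above.
RETRIEVAL_TASKS = {
    "img2txt": ("img2txt_retrieval.py", "--sample_num"),
    "txt2img": ("txt2img_retrieval.py", "--text_prompt"),
    "visual_concept": ("visual_concept_retrieval.py", "--image_path"),
}


def retrieval_command(task, log_folder, value):
    """Return the argument list for one retrieval run."""
    script, flag = RETRIEVAL_TASKS[task]
    return ["python", script, "--log", str(log_folder), flag, str(value)]


def run_retrieval(task, log_folder, value):
    # A list argument is passed to the program without shell word-splitting.
    subprocess.run(retrieval_command(task, log_folder, value), check=True)


cmd = retrieval_command("txt2img", "logs/ade_ow/cgqa", "squatting catcher")
print(shlex.join(cmd))
# python txt2img_retrieval.py --log logs/ade_ow/cgqa --text_prompt 'squatting catcher'
```

The printed line is the equivalent shell command, with the prompt quoted so that it can be pasted into a terminal as is.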
You can also adjust the coefficient (args.aow) for different datasets in retrieval tasks, referring to the chosen …
Dataset | Closed-World | Open-World |
---|---|---|
Clothing | | |
UT-Zappos | | |
CGQA | | |
Results (AUC and HM in %; CW = closed-world, OW = open-world):
Dataset | AUC (CW) | HM (CW) | AUC (OW) | HM (OW) |
---|---|---|---|---|
Clothing | 92.4 | 88.7 | 68.0 | 74.2 |
UT-Zappos | 35.1 | 51.1 | 27.1 | 44.8 |
CGQA | 5.2 | 18.0 | 1.42 | 7.6 |
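For reference, the AUC and HM columns follow the usual CZSL evaluation protocol. A calibration bias on the scores of unseen compositions is swept to trace a curve of seen accuracy S against unseen accuracy U; AUC is the area under that curve, and HM is the best harmonic mean 2SU/(S+U) along it. The sketch below computes both from a list of (seen, unseen) accuracy pairs. It is a simplified illustration, not the evaluation code in this repository, which may differ in details such as how the bias values are chosen.

```python
def hm_and_auc(points):
    """Return (best harmonic mean, area under the seen/unseen accuracy curve).

    points: (seen_acc, unseen_acc) pairs, one per calibration bias value.
    """
    # Best harmonic mean over all operating points (defined as 0 when S + U == 0).
    best_hm = max((2 * s * u / (s + u)) if (s + u) > 0 else 0.0 for s, u in points)
    # Trapezoidal integration of seen accuracy as a function of unseen accuracy.
    curve = sorted(points, key=lambda p: p[1])
    auc = 0.0
    for (s0, u0), (s1, u1) in zip(curve, curve[1:]):
        auc += (u1 - u0) * (s0 + s1) / 2
    return best_hm, auc


# Toy operating points, made up for illustration (not results from the paper).
points = [(0.9, 0.0), (0.8, 0.5), (0.5, 0.8), (0.0, 0.9)]
best_hm, auc = hm_and_auc(points)
print(f"HM = {best_hm:.3f}, AUC = {auc:.3f}")  # HM = 0.615, AUC = 0.645
```

Multiplying both values by 100 gives the percentage scale used in the table.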
If you use this code in your research, please consider citing our paper:
@InProceedings{hao2023ade,
title={Learning Attention as Disentangler for Compositional Zero-shot Learning},
author={Hao, Shaozhe and Han, Kai and Wong, Kwan-Yee K.},
booktitle={CVPR},
year={2023}}
Our project is based on CZSL. Thanks to its authors for open-sourcing it!