This is the official PyTorch implementation of SIMLA. The repository is heavily based on salesforce/ALBEF and supports vision-language pretraining as well as finetuning on several downstream tasks.
conda env create --name simla --file environment.yaml
See individual sections below for instructions.
- Pretrained on 4M images.
- Use this one if you want to finetune the model for another downstream VL task, like VQA.
- Finetuned on COCO.
- Use this one for retrieval tasks.
The checkpoints are around 3GB, and contain the optimizer state and everything else needed to resume training.
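If you only need the model weights (for inference or to start a fresh finetuning run), a minimal sketch like the one below can slim a checkpoint down. This is not the repo's own loading code: the filename and the `model` key are assumptions based on ALBEF-style checkpoints, so inspect the keys of your file first.

```python
import torch

# Minimal sketch: inspect a SIMLA checkpoint and keep only the model weights.
# The filename and the 'model' key are assumptions -- check ckpt.keys() first.
ckpt = torch.load("simla_pretrained_4m.pth", map_location="cpu")  # hypothetical filename
print(list(ckpt.keys()))  # expect something like ['model', 'optimizer', ...]
state_dict = ckpt["model"] if "model" in ckpt else ckpt
torch.save(state_dict, "simla_weights_only.pth")  # much smaller than the full ~3GB file
```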
Downloading these exact datasets is unnecessary; pretraining only requires image-text pairs, so any image-text pair dataset will do.
- Download COCO from the official website (use COCO2014, download it all).
- Download SBU Captions using Huggingface (a minimal download sketch is included at the end of this section).
- Download Conceptual Captions using Huggingface.
- Download the weights of DALL-E's D-VAE (encoder, decoder), and place them in a folder.
- Edit `configs/Pretrain.yaml` and change `image_tokenizer_path: /net/acadia10a/data/zkhan/dall-e-tokenizer-weights` to the folder where you downloaded the DALL-E tokenizer weights.
- Generate the pretraining JSON (a minimal generation sketch is shown after the pretraining command below). You can download an example from ALBEF.
- The JSON is a list of dictionaries, one for each image: `{'image': '/absolute/path/to/image', 'caption': 'the caption of the image'}`.
- We made a JSON file for each dataset we used (COCO, SBU, CC3M), but you can just have one file for all the image-text pairs.
- Edit `configs/Pretrain.yaml` and point it to your JSON, so `train_file: /path/to/your/pretraining.json`.
- Run the command below.
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config configs/Pretrain.yaml --output_dir <where to save>
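For reference, here is a minimal sketch of how the pretraining JSON could be generated from the COCO2014 caption annotations. The paths are placeholders, and the `images`/`annotations` layout is the one used by the official COCO caption annotation files; adapt the same pattern for SBU or CC3M.

```python
import json
import os

# Minimal sketch: build the pretraining JSON from COCO2014 caption annotations.
# Paths are placeholders -- point them at your own download.
coco_root = "/path/to/coco2014"
ann_file = os.path.join(coco_root, "annotations", "captions_train2014.json")

coco = json.load(open(ann_file))
id_to_file = {img["id"]: img["file_name"] for img in coco["images"]}

pairs = []
for ann in coco["annotations"]:
    pairs.append({
        "image": os.path.join(coco_root, "train2014", id_to_file[ann["image_id"]]),
        "caption": ann["caption"],
    })

with open("pretraining.json", "w") as f:
    json.dump(pairs, f)
```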
Pretraining on 8x A100s with a batch size of 64 for 30 epochs takes roughly 7 days, and uses about 73GB of GPU memory per GPU.
If you're using A100s or A6000s, you may need to run `export NCCL_P2P_DISABLE=1` in the shell before training.
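For the SBU Captions download mentioned above, a sketch along these lines pulls the captions with the Huggingface `datasets` library and fetches the images. The Hub dataset name `sbu_captions` and its `image_url`/`caption` columns are assumptions, so check the dataset card (Conceptual Captions works analogously), and expect a fraction of the URLs to be dead.

```python
import json
import os
import requests
from datasets import load_dataset

# Minimal sketch, not a robust downloader: the dataset id "sbu_captions" and its
# "image_url"/"caption" columns are assumptions -- verify them on the Hub first.
out_dir = "/path/to/sbu_images"
os.makedirs(out_dir, exist_ok=True)

sbu = load_dataset("sbu_captions", split="train")
pairs = []
for i, row in enumerate(sbu):
    path = os.path.join(out_dir, f"{i}.jpg")
    try:
        resp = requests.get(row["image_url"], timeout=5)
        resp.raise_for_status()
        with open(path, "wb") as f:
            f.write(resp.content)
    except Exception:
        continue  # many SBU links are dead; skip them
    pairs.append({"image": path, "caption": row["caption"]})

with open("sbu_pairs.json", "w") as f:
    json.dump(pairs, f)
```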
- Download the JSON files for finetuning here.
- Next, download the COCO2017 train images and val images from the official website, and move all the images into one directory.
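To merge the two splits into a single directory, a small sketch like the one below works. The `train2017/` and `val2017/` folder names come from the official zips; the target directory name is arbitrary, as long as the config points to it.

```python
import shutil
from pathlib import Path

# Minimal sketch: copy COCO2017 train and val images into one flat directory.
# Adjust the source paths to wherever you unpacked the official zips.
src_dirs = [Path("/path/to/coco/train2017"), Path("/path/to/coco/val2017")]
dst = Path("/path/to/coco/images")  # arbitrary name; this is the single image directory
dst.mkdir(parents=True, exist_ok=True)

for src in src_dirs:
    for img in src.glob("*.jpg"):
        shutil.copy2(img, dst / img.name)
```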
Edit `train_file`, `val_file`, and `test_file` in `configs/Retrieval_coco.yaml` to point to the respective JSON files you downloaded in Step 1.
Note that the test annotations are not public, so we report results on the validation split following previous work.
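You can of course edit the YAML by hand; the sketch below just illustrates changing the keys programmatically. The value types (list vs. single path) are assumptions based on the ALBEF-style configs this repo builds on, so check the existing values in your copy of `configs/Retrieval_coco.yaml` first. The same keys apply to `configs/Retrieval_flickr.yaml` later on.

```python
import yaml

# Minimal sketch: point the retrieval config at your annotation files.
# Value types mirror ALBEF-style configs -- check the existing values
# in your copy of the file before overwriting them.
cfg_path = "configs/Retrieval_coco.yaml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

cfg["train_file"] = ["/path/to/coco_train.json"]  # assumed to be a list of files
cfg["val_file"] = "/path/to/coco_val.json"
cfg["test_file"] = "/path/to/coco_test.json"

with open(cfg_path, "w") as f:
    yaml.safe_dump(cfg, f)
```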
python -m torch.distributed.launch --master_port=49770 --nproc_per_node=2 --use_env Retrieval.py \
--config ./configs/Retrieval_coco.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
Download Flickr30k from Kaggle.
Download the annotations for Flickr30k here.
Edit `train_file`, `val_file`, and `test_file` in `configs/Retrieval_flickr.yaml` to point to the respective JSON files you downloaded in Step 1.
We do not use the validation split, so that key can be set to the name of the train or test file.
python -m torch.distributed.launch --master_port=47770 --nproc_per_node=2 --use_env zero_shot_retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth> \
--evaluate
Data preparation is the same as above.
python -m torch.distributed.launch --master_port=49770 --nproc_per_node=2 --use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
python -m torch.distributed.launch --master_port=49121 --nproc_per_node=2 --use_env Grounding.py \
--config ./configs/Grounding.yaml \
--output_dir <path/to/output> \
--gradcam_mode itm \
--block_num 8 \
--checkpoint <path/to/checkpoint.pth>
python -m torch.distributed.launch --nproc_per_node=2 --use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
python -m torch.distributed.launch --nproc_per_node=2 --use_env Pretrain_NLVR.py \
--config ./configs/NLVR_Pretrain.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
python -m torch.distributed.launch --nproc_per_node=2 --use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
python -m torch.distributed.launch --master_port=47770 --nproc_per_node=2 \
--use_env VE.py \
--config ./configs/VE.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
@inproceedings{SIMLA,
title={Single-Stream Multi-Level Alignment for Vision-Language Pretraining},
author={Zaid Khan and Vijay Kumar BG and Xiang Yu and Samuel Schulter and Manmohan Chandraker and Yun Fu},
year={2022},
booktitle={ECCV}
}