[ECCV 22] Single Stream Multi-Level Alignment for Vision Language Pretraining

codezakh/SIMLA

SIMLA: Single-Stream Multi-Level Alignment for Vision-Language Pretraining, ECCV 2022 (NEC Labs)

This is the official PyTorch implementation of SIMLA. The repository is heavily based on salesforce/ALBEF and supports vision-language pretraining as well as finetuning on several downstream tasks.

Setup

Dependencies

conda env create --name simla --file environment.yaml

Data

See the individual sections below for data preparation instructions.

Checkpoints

The checkpoints are around 3 GB each and contain the optimizer state and everything else needed to resume training.
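Before resuming or finetuning, it can help to look inside a downloaded checkpoint. The following is a minimal sketch, not part of the repository: summarize_checkpoint and load_checkpoint are hypothetical helpers, and the sketch deliberately does not assume any particular top-level key names (ALBEF-style checkpoints usually store the model and optimizer state as separate entries, but check the printed output rather than relying on that).

```python
from typing import Any, Dict


def summarize_checkpoint(ckpt: Dict[str, Any]) -> Dict[str, str]:
    """Describe each top-level entry of a loaded checkpoint dict.

    Nested dicts (e.g. a model or optimizer state dict) are reported by
    their entry count; any other value is reported by its type name.
    No specific key names are assumed.
    """
    summary = {}
    for key, value in ckpt.items():
        if isinstance(value, dict):
            summary[key] = f"dict ({len(value)} entries)"
        else:
            summary[key] = type(value).__name__
    return summary


def load_checkpoint(path: str) -> Dict[str, Any]:
    # Imported lazily so summarize_checkpoint works without PyTorch installed.
    import torch

    # map_location="cpu" avoids allocating the ~3 GB checkpoint on a GPU.
    # On PyTorch >= 2.6, torch.load defaults to weights_only=True; if loading
    # fails because the file stores non-tensor objects, pass
    # weights_only=False, but only for checkpoints you trust.
    return torch.load(path, map_location="cpu")


# Usage (hypothetical path):
# print(summarize_checkpoint(load_checkpoint("path/to/checkpoint.pth")))
```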

Pretraining

You do not need to download the exact datasets we used: pretraining only requires image-text pairs, so any image-text pair dataset will do.

  1. Download the weights of DALL-E's dVAE (encoder and decoder) and place them in a folder.
  2. Edit configs/Pretrain.yaml and change image_tokenizer_path: /net/acadia10a/data/zkhan/dall-e-tokenizer-weights so that it points to the folder where you downloaded the DALL-E tokenizer weights.
  3. Generate the pretraining JSON. You can download an example from ALBEF.
    • The JSON is a list of dictionaries, one for each image: {"image": "/absolute/path/to/image", "caption": "the caption of the image"}.
    • We made a JSON file for each dataset we used (COCO, SBU, CC3M), but you can also put all of your image-text pairs in a single file.
  4. Edit configs/Pretrain.yaml and point train_file to your JSON, i.e. train_file: /path/to/your/pretraining.json.
  5. Run the command below.
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config configs/Pretrain.yaml --output_dir <where to save> 
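Step 3 can be sketched as follows, assuming you have already collected your (image path, caption) pairs, for example by parsing your dataset's own annotation files (not covered here). build_pretrain_json is a hypothetical helper; it writes the list-of-dictionaries format described in step 3, converts image paths to absolute paths, and skips empty captions.

```python
import json
from pathlib import Path
from typing import Iterable, List, Tuple


def build_pretrain_json(pairs: Iterable[Tuple[str, str]], out_path: str) -> List[dict]:
    """Write image-caption pairs as a JSON list of {"image", "caption"} dicts."""
    records = []
    for image_path, caption in pairs:
        caption = caption.strip()
        if not caption:
            continue  # skip pairs with no usable text
        # resolve() turns relative paths into absolute ones, as step 3 expects.
        records.append({"image": str(Path(image_path).resolve()), "caption": caption})
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(records, f)
    return records


# Usage (hypothetical pairs and output path):
# build_pretrain_json([("/data/img/0.jpg", "a dog on a beach")], "pretraining.json")
```

Writing one file per dataset or one combined file both work with this helper; just call it once per source or once on the concatenated pairs.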

Pretraining on 8x A100s with a batch size of 64 for 30 epochs takes roughly 7 days and uses about 73 GB of memory per GPU. If you are training on A100s or A6000s, you may need to run export NCCL_P2P_DISABLE=1 in the shell before training.
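If you launch pretraining from a script rather than an interactive shell, the NCCL setting has to be present in the environment of the launched processes. A minimal sketch of such a wrapper, which is hypothetical and not part of the repository: pretrain_command mirrors the launch command above and adds NCCL_P2P_DISABLE=1 to a copy of the current environment, which is equivalent to exporting it in the shell first.

```python
import os
import subprocess
from typing import Dict, List, Tuple


def pretrain_command(output_dir: str, nproc: int = 8,
                     disable_p2p: bool = True) -> Tuple[List[str], Dict[str, str]]:
    """Build the pretraining launch command and the environment to run it in."""
    cmd = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc}", "--use_env", "Pretrain.py",
        "--config", "configs/Pretrain.yaml",
        "--output_dir", output_dir,
    ]
    env = dict(os.environ)  # copy, so the caller's environment is not modified
    if disable_p2p:
        env["NCCL_P2P_DISABLE"] = "1"
    return cmd, env


def launch(output_dir: str, **kwargs) -> None:
    # Starts training; raises CalledProcessError if the launcher exits non-zero.
    cmd, env = pretrain_command(output_dir, **kwargs)
    subprocess.run(cmd, env=env, check=True)
```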

Image Text Retrieval

  1. Download the JSON files for finetuning here.
  2. Next, download the COCO2017 train and val images from the official website and move all of the images into a single directory.
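Step 2 can be done with a short script. This is a minimal sketch under the assumption that the official archives were extracted to train2017/ and val2017/; merge_image_dirs and the destination name are hypothetical. It refuses to overwrite on a filename clash rather than silently dropping an image; COCO 2017 filenames encode unique image IDs, so no clash is expected between the two splits.

```python
import shutil
from pathlib import Path
from typing import Iterable


def merge_image_dirs(src_dirs: Iterable[str], dest_dir: str, move: bool = True) -> int:
    """Move (or copy) every file from src_dirs into one flat dest_dir.

    Raises FileExistsError on a filename clash instead of overwriting.
    Returns the number of files transferred.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    count = 0
    for src in src_dirs:
        # sorted() materializes the listing before files are moved out of it.
        for path in sorted(Path(src).iterdir()):
            if not path.is_file():
                continue
            target = dest / path.name
            if target.exists():
                raise FileExistsError(target)
            if move:
                shutil.move(str(path), str(target))
            else:
                shutil.copy2(path, target)
            count += 1
    return count


# Usage (hypothetical paths):
# merge_image_dirs(["coco/train2017", "coco/val2017"], "coco/images")
```

Pass move=False to keep the original split folders intact at the cost of roughly doubling disk usage.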

Finetuned (COCO)

Edit train_file, val_file and test_file in configs/Retrieval_coco.yaml to point to the corresponding JSON files you downloaded in step 1. Because the COCO test annotations are not public, we follow previous work and report results on the validation split.
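Config edits like this one (and the train_file and image_tokenizer_path edits in configs/Pretrain.yaml) can also be scripted. A minimal sketch with a hypothetical helper, set_top_level_keys: it rewrites matching top-level key: value lines in place instead of round-tripping through a YAML library, so comments and key order are preserved. Values are inserted verbatim as YAML, because in ALBEF-derived configs train_file is often a list while val_file and test_file are plain strings; that is an assumption to verify against the actual file.

```python
import re
from typing import Dict


def set_top_level_keys(yaml_text: str, updates: Dict[str, str]) -> str:
    """Replace the values of existing top-level `key: value` lines.

    Each value in `updates` is raw YAML (quote strings yourself). Only
    single-line values are handled, and an inline comment on a replaced
    line is dropped. Raises KeyError if any key is not found.
    """
    lines = yaml_text.splitlines(keepends=True)
    remaining = dict(updates)
    for i, line in enumerate(lines):
        # Top-level keys start at column 0; indented or commented lines never match.
        match = re.match(r"^([A-Za-z_][\w-]*)\s*:", line)
        if match and match.group(1) in remaining:
            key = match.group(1)
            newline = "\n" if line.endswith("\n") else ""
            lines[i] = f"{key}: {remaining.pop(key)}{newline}"
    if remaining:
        raise KeyError(f"keys not found: {sorted(remaining)}")
    return "".join(lines)


# Usage (hypothetical paths):
# with open("configs/Retrieval_coco.yaml") as f:
#     text = f.read()
# text = set_top_level_keys(text, {
#     "train_file": "['/data/coco_train.json']",
#     "val_file": "'/data/coco_val.json'",
#     "test_file": "'/data/coco_test.json'",
# })
```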

python -m torch.distributed.launch --master_port=49770 --nproc_per_node=2 --use_env Retrieval.py \
--config ./configs/Retrieval_coco.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>

Zero-Shot (Flickr)

Download Flickr30k from Kaggle, then download the annotations for Flickr30k here. Edit train_file, val_file and test_file in configs/Retrieval_flickr.yaml to point to the corresponding Flickr30k annotation files. We do not use the validation split, so val_file can be set to the path of the train or test file.
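Since the images and annotations come from separate downloads, it is worth confirming that every image referenced in an annotation file exists on disk before launching evaluation. A minimal sketch with a hypothetical helper, missing_images; it assumes each annotation entry has an "image" key (as in the pretraining format above) holding a path relative to some image root directory. Both assumptions should be checked against the downloaded JSON and configs/Retrieval_flickr.yaml.

```python
import json
import os
from typing import List


def missing_images(ann_file: str, image_root: str) -> List[str]:
    """Return annotation image paths that do not exist under image_root.

    Absolute paths in the annotations are checked as-is, because
    os.path.join discards image_root when the second part is absolute.
    """
    with open(ann_file, encoding="utf-8") as f:
        annotations = json.load(f)
    missing, seen = [], set()
    for entry in annotations:
        rel = entry["image"]
        if rel in seen:
            continue  # several captions may share one image
        seen.add(rel)
        if not os.path.isfile(os.path.join(image_root, rel)):
            missing.append(rel)
    return missing


# Usage (hypothetical paths):
# print(missing_images("flickr30k_test.json", "/data/flickr30k"))
```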

python -m torch.distributed.launch --master_port=47770 --nproc_per_node=2 --use_env zero_shot_retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth> \
--evaluate

Finetuned (Flickr)

Prepare the data and config as for zero-shot Flickr retrieval above, then run:

python -m torch.distributed.launch --master_port=49770 --nproc_per_node=2 --use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>

RefCOCO+ (Visual Grounding)

python -m torch.distributed.launch --master_port=49121 --nproc_per_node=2 --use_env Grounding.py \
--config ./configs/Grounding.yaml \
--output_dir <path/to/output> \
--gradcam_mode itm \
--block_num 8 \
--checkpoint <path/to/checkpoint.pth>

VQA (Visual Question Answering)

python -m torch.distributed.launch --nproc_per_node=2 --use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth> 

NLVR (Natural Language Visual Reasoning)

Pretraining

python -m torch.distributed.launch --nproc_per_node=2 --use_env Pretrain_NLVR.py \
--config ./configs/NLVR_Pretrain.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth> 

Finetuning

python -m torch.distributed.launch --nproc_per_node=2 --use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>
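If the finetuning stage is started from the checkpoint written by the NLVR pretraining stage (an assumption; the commands above only show a --checkpoint argument for each), you need the latest checkpoint in the pretraining output directory. A minimal sketch with a hypothetical helper, latest_checkpoint; the default filename pattern checkpoint_<epoch>.pth is borrowed from ALBEF-style training scripts and should be adjusted to whatever Pretrain_NLVR.py actually writes.

```python
import re
from pathlib import Path
from typing import Optional


def latest_checkpoint(output_dir: str,
                      pattern: str = r"checkpoint_(\d+)\.pth") -> Optional[Path]:
    """Return the checkpoint with the highest epoch number, or None.

    Epochs are compared as integers, so checkpoint_10 beats checkpoint_09
    regardless of zero padding.
    """
    best, best_epoch = None, -1
    regex = re.compile(pattern)
    for path in Path(output_dir).iterdir():
        match = regex.fullmatch(path.name)
        if match and int(match.group(1)) > best_epoch:
            best, best_epoch = path, int(match.group(1))
    return best


# Usage (hypothetical directory):
# print(latest_checkpoint("output/nlvr_pretrain"))
```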

SNLI-VE (Visual Entailment)

python -m torch.distributed.launch --master_port=47770 --nproc_per_node=2 \
--use_env VE.py \
--config ./configs/VE.yaml \
--output_dir <path/to/output> \
--checkpoint <path/to/checkpoint.pth>

Citation

@inproceedings{SIMLA,
      title={Single-Stream Multi-Level Alignment for Vision-Language Pretraining},
      author={Zaid Khan and Vijay Kumar BG and Xiang Yu and Samuel Schulter and Manmohan Chandraker and Yun Fu},
      year={2022},
      booktitle={ECCV}
}
