VILA: On Pre-training for Visual Language Models


VILA arxiv / VILA Demo / VILA Huggingface

πŸ’‘ Introduction

VILA is a visual language model (VLM) pretrained with interleaved image-text data at scale, enabling multi-image VLM capabilities. VILA is deployable on the edge, including Jetson Orin and laptops, via AWQ 4-bit quantization through the TinyChat framework. We find that: (1) image-text pairs are not enough; interleaved image-text data is essential; (2) unfreezing the LLM during interleaved image-text pre-training enables in-context learning; (3) re-blending text-only instruction data is crucial to boost both VLM and text-only performance. VILA unveils appealing capabilities, including multi-image reasoning, in-context learning, visual chain-of-thought, and better world knowledge.

πŸ’‘ News

  • [2024/03] VILA-2.7B is released! It can run on NVIDIA Jetson Orin Nano (Tutorial) and appeared at GTC 2024!
  • [2024/03] VILA is accepted by CVPR 2024!
  • [2024/02] We release AWQ-quantized 4bit VILA models, deployable on Jetson Orin and laptops through TinyChat and TinyChatEngine.
  • [2024/02] VILA is released. We propose interleaved image-text pretraining that enables multi-image VLM. VILA comes with impressive in-context learning capabilities. We open-source everything, including training code, evaluation code, datasets, and model checkpoints.
  • [2023/12] Paper is on arXiv!

Performance

Model Prec. VQAv2 GQA VizWiz SQA-I VQA-T POPE MME MMB MMB-CN SEED LLaVA-Bench MM-Vet Average (w/o MME)
VILA-7B fp16 80.3 63.1 59.6 68.0 62.6 86.3 1489.4 69.8 61.0 61.7 75.2 35.1 65.7
VILA-7B-AWQ int4 80.1 63.0 57.8 68.0 61.9 85.3 1486.3 68.8 59.0 61.3 75.8 35.9 65.2
VILA-13B fp16 80.5 63.6 63.1 70.5 64.0 86.3 1553.6 73.8 66.7 62.8 78.3 42.6 68.4
VILA-13B-AWQ int4 80.4 63.6 63.0 71.2 63.5 87.0 1552.9 73.6 66.3 62.2 77.6 42.0 68.2

NOTE: The benchmark results differ slightly from those reported in the paper because the codebase was refactored on top of LLaVA-1.5 and the models were re-trained. VQAv2 and VizWiz results are on the test-dev split.

Inference speed (tokens/sec)

Model Precision A100 4090 Orin
VILA-7B fp16 81.6 58.5 11.5
VILA-7B-AWQ int4 155.3 168.1 35.6
VILA-13B fp16 48.5 OOM 6.1
VILA-13B-AWQ int4 102.1 99.0 17.5

VILA Examples

In context learning

Multi-image reasoning

VILA on Jetson Orin

VILA-13B_Orin_deer.mp4

VILA on RTX 4090

vila_4090_two_cars_3x.mp4

Installation

./environment_setup.sh

or follow the instructions below in order.

conda create -n vila python=3.10 -y
conda activate vila

pip install --upgrade pip  # enable PEP 660 support
# Pre-built FlashAttention wheel for CUDA 11.8, PyTorch 2.0, Python 3.10
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.4.2/flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# Install VILA and its training dependencies
pip install -e .
pip install -e ".[train]"

pip install git+https://github.com/huggingface/[email protected]
# Overwrite selected transformers model files with VILA's patched versions
# (adjust the site-packages path if your conda environment lives elsewhere)
cp -rv ./llava/train/transformers_replace/* ~/anaconda3/envs/vila/lib/python3.10/site-packages/transformers/models/
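
After installation, a quick sanity check along these lines (an optional suggestion, not part of the official setup) confirms that the pinned transformers, FlashAttention, and the llava package are importable:

python -c "import transformers, flash_attn, llava; print('transformers', transformers.__version__); print('flash_attn', flash_attn.__version__)"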

Training

VILA training consists of three steps:

Step-1: Alignment

We use the LLaVA-CC3M-Pretrain-595K dataset to align the textual and visual modalities.

The stage 1 script takes two parameters and can run on a single 8xA100 node. BASE_MODEL_PATH points to an online or local Hugging Face repository, such as NousResearch/Llama-2-7b-hf. OUTPUT_NAME points to a target directory under checkpoints, where the trained multimodal projector will be saved.

bash scripts/v1_5/paper/1_mm_align.sh [BASE_MODEL_PATH] [OUTPUT_NAME]
Hyperparameter Global Batch Size Learning rate Epochs Max length Weight decay
VILA-7B 256 2e-5 1 4096 0
VILA-13B 256 2e-5 1 4096 0
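
For example, using the base model mentioned above and a hypothetical output name (vila-7b-stage1 is only an illustration), the invocation would be:

bash scripts/v1_5/paper/1_mm_align.sh NousResearch/Llama-2-7b-hf vila-7b-stage1

The trained projector then lands under checkpoints/vila-7b-stage1.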

Step-2: Pretraining

We use the MMC4 and COYO datasets to train the VLM with interleaved image-text data.

bash scripts/v1_5/paper/2_pretrain_mmc4_coyo.sh [CODE_PATH] [BASE_MODEL_PATH] [STAGE1_PATH] [OUTPUT_NAME]

The stage 2 script takes four arguments. CODE_PATH is the absolute path to the VILA codebase, and BASE_MODEL_PATH has the same meaning as in the stage 1 script. STAGE1_PATH points to the OUTPUT_NAME of stage 1 (i.e., where the stage 1 checkpoint is stored). OUTPUT_NAME is the desired folder name under checkpoints that saves the pretraining checkpoint. The script provided for this stage runs on SLURM, and we expect it to execute on 16 nodes (128 GPUs).

Hyperparameter Global Batch Size Learning rate Epochs Max length Weight decay
VILA-7B 1024 5e-5 1 4096 0
VILA-13B 1024 5e-5 1 4096 0
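
As an illustration (the code path and names below are hypothetical placeholders), a stage 2 launch continuing from the stage 1 output above might look like:

bash scripts/v1_5/paper/2_pretrain_mmc4_coyo.sh /absolute/path/to/VILA NousResearch/Llama-2-7b-hf vila-7b-stage1 vila-7b-stage2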

Step-3: Supervised fine-tuning

This is the last stage of VILA training, in which we tune the model to follow multimodal instructions on a subset of M3IT, FLAN, and ShareGPT4V. This stage runs on an 8xA100 node.

bash scripts/v1_5/paper/3_sft.sh [STAGE2_PATH] [OUTPUT_NAME]

The stage 3 script takes in two arguments. STAGE2_PATH points to the OUTPUT_NAME of the stage 2 script (i.e. where the stage 2 checkpoint is stored). OUTPUT_NAME is the desired folder name under checkpoints that stores the final checkpoint.

Hyperparameter Global Batch Size Learning rate Epochs Max length Weight decay
VILA-7B 128 2e-5 1 4096 0
VILA-13B 128 2e-5 1 4096 0
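
Continuing the same hypothetical example, the SFT stage would then be launched as:

bash scripts/v1_5/paper/3_sft.sh vila-7b-stage2 vila-7b-sft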

Training with fewer GPUs

To train with fewer GPUs/nodes, reduce per_device_train_batch_size and increase gradient_accumulation_steps accordingly. As long as the global batch size (per_device_train_batch_size x gradient_accumulation_steps x num_gpus) is kept the same, training results will not be affected.
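
As a worked example for the stage 3 recipe (global batch size 128; the per-device values below are illustrative, not the ones hard-coded in the scripts):

# global batch = per_device_train_batch_size x gradient_accumulation_steps x num_gpus
# 8 GPUs: 16 x 1 x 8 = 128
# 4 GPUs: 16 x 2 x 4 = 128
# 2 GPUs: 16 x 4 x 2 = 128
# 1 GPU:  16 x 8 x 1 = 128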

Stage 1 completes within 3.5 (7B) - 5.5 (13B) hours on 8xA100, Stage 2 completes within 30 hours on 128xA100 for VILA-7B, and Stage 3 completes in 25 (7B) - 40 (13B) hours on 8xA100.

See data_prepare/README.md for more information about how to prepare datasets.

Evaluations

You can follow the LLaVA-1.5 evaluation instructions to download all datasets. After downloading, put them under playground/data/eval.

We provide a push-the-button script to perform evaluation on all 10 datasets that do not require GPT-assisted evaluation:

./scripts/v1_5/eval/eval_all.sh [CHECKPOINT_PATH] [MODEL_NAME]

This script takes two parameters: CHECKPOINT_PATH points to the stage 3 model checkpoint, and MODEL_NAME is the name under which the evaluation results are saved.
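
For instance, evaluating the hypothetical stage 3 checkpoint from the training section could look like:

./scripts/v1_5/eval/eval_all.sh checkpoints/vila-7b-sft vila-7b-sft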

VQAv2 and VizWiz evaluations are hosted on eval.ai. You need to register an account and create a team to submit results.

MMBench and MMBench-CN evaluations are hosted on another evaluation server. Make sure you change the name of the file before submitting; otherwise the server caches results and will return incorrect results.

We provide a quick script to automatically organize the prediction files that need to be submitted to servers:

python scripts/v1_5/eval/copy_predictions.py [MODEL_NAME]

You will be able to find the predictions under playground/data/predictions_upload/[MODEL_NAME] after executing this script.
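
Continuing the hypothetical model name from above, the two steps together would be:

python scripts/v1_5/eval/copy_predictions.py vila-7b-sft
ls playground/data/predictions_upload/vila-7b-sft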

Inference

We provide snippets for quick inference with user prompts and images.

VILA-7B inference:

python -W ignore llava/eval/run_llava.py \
    --model-path Efficient-Large-Model/VILA-7B \
    --conv-mode vicuna_v1 \
    --query "<image>\n Please describe the traffic condition." \
    --image-file "demo_images/av.png"

VILA-13B inference:

python -W ignore llava/eval/run_llava.py \
    --model-path Efficient-Large-Model/VILA-13B \
    --conv-mode vicuna_v1 \
    --query "<image>\n Please describe the traffic condition." \
    --image-file "demo_images/av.png"
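
VILA also handles multi-image prompts. The sketch below assumes run_llava.py splits --image-file on commas, as in the LLaVA-1.5 codebase this repository builds on; the image paths and query are placeholders rather than shipped demo assets:

python -W ignore llava/eval/run_llava.py \
    --model-path Efficient-Large-Model/VILA-7B \
    --conv-mode vicuna_v1 \
    --query "<image>\n<image>\n What changed between these two scenes?" \
    --image-file "demo_images/scene_1.png,demo_images/scene_2.png"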

Quantization and Deployment

Our VILA models are quantized by AWQ into 4 bits for efficient inference on the edge. We provide a push-the-button script to quantize VILA with AWQ.

Running VILA on desktop GPUs and edge GPUs

We support AWQ-quantized 4-bit VILA on GPU platforms via TinyChat. We provide a tutorial to run the model with TinyChat after quantization. We also provide instructions to launch a Gradio server (powered by TinyChat and AWQ) to serve 4-bit quantized VILA models.

Running VILA on laptops

We further support our AWQ-quantized 4-bit VILA models on various CPU platforms with both x86 and ARM architectures through our TinyChatEngine. We also provide a detailed tutorial to help users deploy VILA on different CPUs.

Checkpoints

We release VILA-7B, VILA-13B, VILA-7B-4bit-AWQ and VILA-13B-4bit-AWQ.

πŸ”’ License

  • The code is released under the Apache 2.0 license as found in the LICENSE file.
  • The pretrained weights are released under the CC-BY-NC-SA-4.0 license.
  • The service is a research preview intended for non-commercial use only, and is subject to the licenses and terms of the underlying models and datasets it uses.

Team

  • Ji Lin: OpenAI (work done at Nvidia and MIT)
  • Hongxu Yin: Nvidia
  • Yao Lu: Nvidia
  • Wei Ping: Nvidia
  • Pavlo Molchanov: Nvidia
  • Andrew Tao: Nvidia
  • Haotian Tang: MIT
  • Shang Yang: MIT
  • Ligeng Zhu: Nvidia, MIT
  • Wei-Chen Wang: MIT
  • Fuzhao Xue: Nvidia, NUS
  • Yunhao Fang: Nvidia, UCSD
  • Yukang Chen: Nvidia, CUHK
  • Yue Shen: Nvidia
  • Huizi Mao: Nvidia
  • Jan Kautz: Nvidia
  • Mohammad Shoeybi: Nvidia
  • Song Han: Nvidia, MIT

Citations

@misc{lin2023vila,
      title={VILA: On Pre-training for Visual Language Models},
      author={Ji Lin and Hongxu Yin and Wei Ping and Yao Lu and Pavlo Molchanov and Andrew Tao and Huizi Mao and Jan Kautz and Mohammad Shoeybi and Song Han},
      year={2023},
      eprint={2312.07533},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

  • LLaVA: the codebase that VILA is built upon.
