Our model supports image captioning and VQA, and performs especially well on OCR-related images. It supports Simplified Chinese, Traditional Chinese, and English. 🤩🤩🤩 Please give me a star if you find it interesting and useful! [🤗Space]
Diagram of OCR_MLLM_TOY Model.
- [1] The OCR image encoder is adopted from an end-to-end OCR recognition model; we use the pretrained weights from Vary.
- [2] The ViT image encoder weights are adopted from QwenVL.
- [2024/03/08] 🔥 We released OCR_MLLM_TOY.
- [2024/03/07] 🔥 We released the training and evaluation code.
Usage and License Notices: The data and code are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Qwen, and Vary.
To run our Gradio demo, download the checkpoints from Hugging Face and put them in "./checkpoints/qwen14b-finetune_all/checkpoint-8300". Then run the following command from the repository root:
python -m ocr_mllm_gradio.my_gradio_web_server --host 0.0.0.0 --port 10000
- Clone this repository
git clone https://github.com/SuXuping/OCR_MLLM_TOY.git
- Install packages
conda create -n OCR_MLLM_TOY python=3.10 -y
conda activate OCR_MLLM_TOY
pip install --upgrade pip
pip install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/
- Install additional packages for training
pip install ninja
pip install flash-attn --no-build-isolation
pip install deepspeed
OCR_MLLM_TOY is trained on 8 A100 GPUs with 80GB of memory each. To train on fewer GPUs, you can:
- [1] reduce per_device_train_batch_size and increase gradient_accumulation_steps accordingly, so that the global batch size stays the same;
- [2] use LoRA during training;
- [3] use a smaller (7B) LLM instead.
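Option [1] relies on the identity global batch size = per_device_train_batch_size × gradient_accumulation_steps × number of GPUs. The Python sketch below works through that arithmetic; the helper name grad_accum_steps is hypothetical (not part of this repo), and the per-device batch sizes in the example are illustrative rather than the values used in the provided scripts.

```python
def grad_accum_steps(global_batch_size: int, num_gpus: int, per_device_batch_size: int) -> int:
    """Return the gradient_accumulation_steps that keeps the global batch size unchanged.

    Hypothetical helper based on:
        global_batch_size = per_device_batch_size * gradient_accumulation_steps * num_gpus
    """
    micro_batch = per_device_batch_size * num_gpus  # samples processed per optimizer micro-step
    if global_batch_size % micro_batch != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{per_device_batch_size} x {num_gpus}"
        )
    return global_batch_size // micro_batch


if __name__ == "__main__":
    # Global batch sizes from the hyperparameter table: 1024 (pretrain), 64 (SFT).
    # Per-device batch sizes here are illustrative assumptions.
    print(grad_accum_steps(1024, num_gpus=8, per_device_batch_size=32))  # 4
    print(grad_accum_steps(1024, num_gpus=4, per_device_batch_size=16))  # 16
    print(grad_accum_steps(64, num_gpus=2, per_device_batch_size=4))     # 8
```

For example, moving from 8 GPUs to 4 while halving the per-device batch size from 32 to 16 raises the accumulation steps from 4 to 16, and the effective batch stays at 1024.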
Prepare your data in this format:
[
{
"id": "image_id",
"image": "image path",
"conversations": [
{
"from": "human",
"value": "Examine the image closely and share its details<image>\n"
},
{
"from": "gpt",
"value": "The image shows a man fishing on a lawn next to a river with a bridge in the background. Trees can be seen on the other side of the river, and the sky is cloudy."
}
]
}
]
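A small script can build samples in this format and catch structural mistakes before training. This is a minimal sketch, not part of the repo: make_sample and validate_samples are hypothetical helpers, and the checks they apply (alternating human/gpt turns starting with human, and an <image> token in the first human turn) are assumptions inferred from the single example above.

```python
import json
from typing import Dict, List


def make_sample(image_id: str, image_path: str, question: str, answer: str) -> Dict:
    """Build one single-turn sample in the conversation format shown above."""
    return {
        "id": image_id,
        "image": image_path,
        "conversations": [
            # The human turn carries the <image> placeholder, as in the example.
            {"from": "human", "value": f"{question}<image>\n"},
            {"from": "gpt", "value": answer},
        ],
    }


def validate_samples(samples: List[Dict]) -> None:
    """Raise ValueError if a sample does not follow the assumed structure."""
    for i, sample in enumerate(samples):
        for key in ("id", "image", "conversations"):
            if key not in sample:
                raise ValueError(f"sample {i} is missing '{key}'")
        turns = sample["conversations"]
        # Assumption: turns come in human/gpt pairs.
        if not turns or len(turns) % 2 != 0:
            raise ValueError(f"sample {i} must contain human/gpt turn pairs")
        for j, turn in enumerate(turns):
            expected = "human" if j % 2 == 0 else "gpt"
            if turn.get("from") != expected or "value" not in turn:
                raise ValueError(f"sample {i}, turn {j}: expected a '{expected}' turn with a value")
        # Assumption: the first human turn references the image.
        if "<image>" not in turns[0]["value"]:
            raise ValueError(f"sample {i}: first human turn lacks '<image>'")


if __name__ == "__main__":
    data = [make_sample("000001", "images/000001.jpg",
                        "Examine the image closely and share its details",
                        "The image shows a man fishing on a lawn next to a river.")]
    validate_samples(data)
    print(json.dumps(data, ensure_ascii=False, indent=2))
```

To write the training file, pass the validated list to json.dump(data, f, ensure_ascii=False, indent=2) with a file opened using encoding="utf-8"; ensure_ascii=False keeps Simplified and Traditional Chinese text readable.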
We use a set of hyperparameters similar to LLaVA's for pretraining and SFT. For more details, see the pretrain and SFT scripts.
| Hyperparameter | Global batch size | Learning rate | Epochs | Max length | Weight decay | DeepSpeed |
|---|---|---|---|---|---|---|
| pretrain | 1024 | 1e-3 | 1 | 1024 | 0 | zero2 |
| SFT | 64 | 2e-5 | 1 | 2048 | 0 | zero3 |
Before you start training your own MLLM, you need to prepare the following weights:
- [1] Prepare your base LLM and put its weights in "./ocr_mllm_toy/pretrain_weight/qwen_pretrain".
- [2] Prepare the ViT image encoder (448 resolution) and put its weights in "./ocr_mllm_toy/pretrain_weight/qwen_vit_448".
- [3] Prepare the OCR image encoder (1024 resolution) and put its weights in "./ocr_mllm_toy/pretrain_weight/vary_pretrain".
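A quick pre-flight check can catch a missing or empty weight directory before a long training run starts. This is a minimal sketch, not part of the repo: missing_weight_dirs is a hypothetical helper, and it only verifies that each directory exists and is non-empty, not that the files inside are valid checkpoints.

```python
from pathlib import Path
from typing import Dict, List

# Weight directories listed above, relative to the repository root.
REQUIRED_WEIGHT_DIRS: Dict[str, str] = {
    "base LLM": "./ocr_mllm_toy/pretrain_weight/qwen_pretrain",
    "ViT image encoder (448)": "./ocr_mllm_toy/pretrain_weight/qwen_vit_448",
    "OCR image encoder (1024)": "./ocr_mllm_toy/pretrain_weight/vary_pretrain",
}


def missing_weight_dirs(root: str = ".") -> List[str]:
    """Return the names of weight directories that are absent or empty under root."""
    missing = []
    for name, rel_path in REQUIRED_WEIGHT_DIRS.items():
        path = Path(root) / rel_path
        # An existing but empty directory usually means the download step was skipped.
        if not path.is_dir() or not any(path.iterdir()):
            missing.append(name)
    return missing


if __name__ == "__main__":
    gaps = missing_weight_dirs()
    if gaps:
        print("Missing weights:", ", ".join(gaps))
    else:
        print("All pretrained weight directories are in place.")
```

Run it from the repository root before calling the pretraining script.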
The pretraining script is provided:
sh ./scripts/pretrain_qwen14b.sh
The SFT script is provided:
sh ./scripts/finetune_lora_qwen14b.sh
We have evaluated OCR_MLLM_TOY on several benchmarks, including TextVQA, MMBench, MMBench-CN, MM-Vet, and MME. On some benchmarks, our model achieves results similar to LLaVA-NeXT-34B.
Evaluation results of the OCR_MLLM_TOY 14B and LLaVA 13B models on the MM-Vet benchmark.
Please see this doc for the details.
To run inference, download our weights from Hugging Face and put them in "./checkpoints/qwen14b-finetune_all/checkpoint-8300". We provide an interactive command-line interface for multimodal inference with streaming output.
python cli.py
We also provide LLM-only inference and multimodal inference in "inference.py".
python inference.py