Deepspeed-VisualChat #753

Merged: 8 commits merged on Oct 3, 2023
116 changes: 116 additions & 0 deletions applications/DeepSpeed-VisualChat/README.md
@@ -0,0 +1,116 @@
## DeepSpeed-VisualChat: Enabling Multi-Round Multi-Image Chat for All Scales

An easy-to-use, scalable, and efficient multi-modal training pipeline for a multi-round, multi-image interleave chat experience.


## Table of Contents

- [📰 Latest News 📰](#-latest-news-)
- [🚀 What is DeepSpeed-VisualChat 🚀️](#-what-is-deepspeed-visualchat-)
- [⚓ Get Started, Tutorial, and Documentation ⚓](#-get-started-tutorial-and-documentation-)
- [🌱 DeepSpeed-VisualChat's Roadmap 🌱](#-deepspeed-visualchats-roadmap-)
- [💬 DeepSpeed-VisualChat and DeepSpeed Community 💬](#-deepspeed-visualchat-and-deepspeed-community-)
- [🙏 Acknowledgement and Citation 🙏](#-acknowledgement-and-citation-)

<!-- markdown-toc end -->

## 📰 Latest News 📰

* ***[2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)***

⭐ If you find our [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) repositories beneficial, please give them a star on GitHub! To cite DeepSpeed-VisualChat, please cite our [arXiv report](https://arxiv.org/abs/2309.14327):

```
@article{yao2023deepspeed-visualchat,
title={{DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention}},
author={Zhewei Yao and Xiaoxia Wu and Conglong Li and Minjia Zhang and Heyang Qin and Olatunji Ruwase and Ammar Ahmad Awan and Samyam Rajbhandari and Yuxiong He},
journal={arXiv preprint arXiv:2309.14327},
year={2023}
}
```

## 🚀 What is DeepSpeed-VisualChat 🚀
<div align="center">

<img src="assets/hero-figure.png" alt="DeepSpeed-VisualChat Banner!"/>
Figure 1. On the left is a DeepSpeed-VisualChat model, featuring an innovative attention design. On the right is an example of DeepSpeed-VisualChat.

</div>

---

With increasing interest in enabling the multi-modal capabilities of large language models, DeepSpeed is proud to announce a new training pipeline named ***DeepSpeed-VisualChat***. It is designed to enable a multi-round, multi-image interleave chat framework, enhancing the language model with image understanding and reasoning capabilities. Unlike the majority of open-sourced multi-modal projects, the primary focus of DeepSpeed-VisualChat is to provide this multi-round, multi-image interleave chat experience, as illustrated in Figure 1.

To improve model quality without introducing new parameters, DeepSpeed-VisualChat incorporates a new multi-modal causal attention mechanism, which is adept at better aligning visual and text features. Additionally, to overcome the scarcity of interleaved text-and-image inputs in most available open-sourced datasets, we employ various data blending techniques on existing datasets.
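
To give a rough flavor of what a multi-modal causal attention mask can look like, below is a minimal, hypothetical sketch (not the implementation used in this repository) in which visual tokens are restricted to attending among themselves while text tokens keep the usual causal view of everything before them; the helper name and tensor shapes are assumptions made purely for illustration:

```python
import torch

def sketch_mmca_mask(is_image: torch.Tensor) -> torch.Tensor:
    """Illustrative only. is_image: bool tensor of shape (seq_len,), True for visual tokens.
    Returns a (seq_len, seq_len) bool mask where entry (i, j) == True means
    query token i may attend to key token j."""
    seq_len = is_image.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # standard causal mask
    img_rows = is_image.unsqueeze(1)  # (seq_len, 1): which query tokens are images
    img_cols = is_image.unsqueeze(0)  # (1, seq_len): which key tokens are images
    # Image queries keep only image keys; text queries keep the full causal view.
    return torch.where(img_rows, causal & img_cols, causal)

# Example: two image tokens followed by three text tokens.
print(sketch_mmca_mask(torch.tensor([True, True, False, False, False])))
```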

Thanks to the scalable, efficient, and user-friendly nature of the DeepSpeed ecosystem, we are able to train with a 2B visual encoder from QWen-VL (itself further fine-tuned from OpenCLIP) and a 70B language decoder from LLaMA-2. This showcases the extraordinary scalability of the DeepSpeed-VisualChat framework.

## ⚓ Get Started, Tutorial, and Documentation ⚓

### 🐼 Installation


```bash
git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/applications/DeepSpeed-VisualChat/
pip install -r requirements.txt
```
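
As an optional sanity check (not part of the official instructions), the one-liner below simply prints the installed DeepSpeed version to confirm the core dependency imports correctly:

```bash
python -c "import deepspeed; print(deepspeed.__version__)"
```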

### 🐼 Datasets Preparation

The table below summarizes where to download the datasets that we support. `{data_path}` denotes the `--data_path` argument provided in the training scripts.

| Dataset name | Where to download |
|--------------|-------------------|
| aokvqa | Download `2017 Train images [118K/18GB]` from [https://cocodataset.org/#download](https://cocodataset.org/#download) and save at `{data_path}/coco/train2017/`. Download `aokvqa_v1p0_train.json` from [https://allenai.org/project/a-okvqa/home](https://allenai.org/project/a-okvqa/home) and save at `{data_path}/aokvqa/annotations/`. |
| coco_caption | Download 2014 Train images and 2014 Val images from [https://cocodataset.org/#download](https://cocodataset.org/#download) and save all images at `{data_path}/coco/2014/`. Download `dataset.json` from [https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip) and save at `{data_path}/coco_caption/`. |
| llava | Download `2017 Train images [118K/18GB]` from [https://cocodataset.org/#download](https://cocodataset.org/#download) and save at `{data_path}/coco/train2017/`. Download `detail_23k.json` and `complex_reasoning_77k.json` from [https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) and save at `{data_path}/llava/`. |
| llava_dial | Download `2017 Train images [118K/18GB]` from [https://cocodataset.org/#download](https://cocodataset.org/#download) and save at `{data_path}/coco/train2017/`. Download `conversation_58k.json` from [https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) and save at `{data_path}/llava/`. |
| llava_otter_blend | Follow the instructions for the llava, llava_dial, and otter_mimicit_cgd datasets above. |
| minigpt4 | Download `image` folder and `filter_cap.json` from [https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) and save at `{data_path}/cc_sbu_align/`. |
| ocr_vqa | Download `images` folder and `dataset.json` from [https://ocr-vqa.github.io/](https://ocr-vqa.github.io/) and save at `{data_path}/OCR_VQA/`. |
| otter_mimicit_cgd | Download `2017 Train images [118K/18GB]` from [https://cocodataset.org/#download](https://cocodataset.org/#download) and save at `{data_path}/coco/train2017/`. Download `CGD_instructions.json` from [https://huggingface.co/datasets/pufanyi/MIMICIT](https://huggingface.co/datasets/pufanyi/MIMICIT) and save at `{data_path}/MIMIC-IT/`. |
| otter_mimicit_sd | Download `SD.json` and `SD_instructions.json` from [https://huggingface.co/datasets/pufanyi/MIMICIT](https://huggingface.co/datasets/pufanyi/MIMICIT) and save at `{data_path}/MIMIC-IT/`. |
| otter_mimicit_sn | Download `SN.json` and `SN_instructions.json` from [https://huggingface.co/datasets/pufanyi/MIMICIT](https://huggingface.co/datasets/pufanyi/MIMICIT) and save at `{data_path}/MIMIC-IT/`. |
| otter_mimicit_tvc | Download `TVC.json` and `TVC_instructions.json` from [https://huggingface.co/datasets/pufanyi/MIMICIT](https://huggingface.co/datasets/pufanyi/MIMICIT) and save at `{data_path}/MIMIC-IT/`. |
| otter_mimicit_vst | Download `VST.json` and `VST_instructions.json` from [https://huggingface.co/datasets/pufanyi/MIMICIT](https://huggingface.co/datasets/pufanyi/MIMICIT) and save at `{data_path}/MIMIC-IT/`. |
| sparkles_dialogue | Download the `SparklesDialogueCC` and `SparklesDialogueVG` folders via the OneDrive link provided at [https://github.com/HYPJUDY/Sparkles](https://github.com/HYPJUDY/Sparkles) and save them at `{data_path}/`. |
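
For orientation, a `{data_path}` directory populated according to the table above might look roughly like the sketch below (illustrative only, showing a subset of the files):

```
{data_path}/
├── coco/
│   ├── train2017/             # COCO 2017 train images (aokvqa, llava, llava_dial, otter_mimicit_cgd)
│   └── 2014/                  # COCO 2014 train and val images (coco_caption)
├── aokvqa/annotations/aokvqa_v1p0_train.json
├── coco_caption/dataset.json
├── llava/                     # detail_23k.json, complex_reasoning_77k.json, conversation_58k.json
├── cc_sbu_align/              # image folder and filter_cap.json (minigpt4)
├── OCR_VQA/                   # images folder and dataset.json (ocr_vqa)
├── MIMIC-IT/                  # CGD/SD/SN/TVC/VST json and *_instructions.json files
├── SparklesDialogueCC/
└── SparklesDialogueVG/
```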

### 🐼 Training, Evaluation, Chat API, and Helper
Please refer to
- [**Training**](./training/README.md)
- [**Evaluation**](./eval/README.md)
- [**Chat**](./chat/README.md)
- [**Helper**](./helper/README.md)


## 🌱 DeepSpeed-VisualChat's Roadmap 🌱

Our future plans include, but are not limited to:
- [ ] Support more models
- [ ] Demonstrate how to train larger models with higher model quality

## 💬 DeepSpeed-VisualChat and DeepSpeed Community 💬

Just like how the success of [the BLOOM model](https://huggingface.co/bigscience/bloom) was supported by both the [DeepSpeed Team](https://github.com/bigscience-workshop/Megatron-DeepSpeed) and many [open source contributors](https://huggingface.co/bigscience), we welcome all AI developers/practitioners/researchers to join this ongoing effort for DeepSpeed-VisualChat. To participate:
- Show your support by leaving a star ⭐ on our [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples) GitHub repositories.
- Follow us on [Twitter](https://twitter.com/MSFTDeepSpeed) to get notified about our latest news. For Chinese users, you can also follow our [Chinese Zhihu account](https://www.zhihu.com/people/deepspeed). For Japanese users, you can also follow our [Japanese Twitter account](https://twitter.com/MSFTDeepSpeedJP).
- Currently, we prefer to interact with open-source users mainly on GitHub so that it's easier for all users to search for related information. For bug reports, please submit a GitHub issue. For contributions, please submit a pull request (PR). For general questions/discussions, please open a new discussion or join an existing one.
- We are open to collaborations with universities, research labs, and companies, such as working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please email [email protected] directly.


## 🙏 Acknowledgement and Citation 🙏

We thank the following papers and open-source repositories:

[1] LLaVA, https://github.com/haotian-liu/LLaVA
[2] Otter, https://github.com/Luodian/Otter
[3] Hugging Face Transformers, https://github.com/huggingface/transformers
[4] MiniGPT-4, https://github.com/Vision-CAIR/MiniGPT-4
[5] QWen-VL, https://github.com/QwenLM/Qwen-VL
[6] Sparkles, https://github.com/HYPJUDY/Sparkles
[7] Multimodal-GPT, https://github.com/open-mmlab/Multimodal-GPT
1 change: 1 addition & 0 deletions applications/DeepSpeed-VisualChat/chat/README.md
@@ -0,0 +1 @@
We provide a CLI interface for users to test their trained chat model. First of all, please note that you need to provide both the trained checkpoint path and the paths to the original language model and vision encoder: the model is first initialized from them and then loads the trained checkpoint. Also, please note that if you used multi-modal causal attention during training, remember to pass `--enable_mmca_attention` in your chat script.
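
For reference, a chat session can be launched along the lines of the sketch below, run from inside the `chat` folder. The model and checkpoint paths are placeholders; substitute the paths you actually used, and drop `--enable_mmca_attention` if you trained without it:

```bash
python chat.py \
   --lm_model_name_or_path /path/to/your/language-model \
   --vision_model_name_or_path /path/to/your/vision-encoder \
   --checkpoint_path /path/to/your/trained-checkpoint \
   --enable_mmca_attention
```
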
169 changes: 169 additions & 0 deletions applications/DeepSpeed-VisualChat/chat/chat.py
@@ -0,0 +1,169 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import argparse
import os
import sys
from PIL import Image

import torch
import deepspeed

sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from utils.utils import print_rank_0
from utils.model import create_dsvl_model_and_transforms
import utils.data.DST as DST
from transformers import AutoTokenizer
from termcolor import colored
import re

def parse_args():
    parser = argparse.ArgumentParser(description="CLI chat")
    parser.add_argument(
        "--lm_model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument("--vision_model_name_or_path", default="openai/clip-vit-large-patch14", type=str)
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        type=str,
        help="Path to the trained DeepSpeed-VisualChat checkpoint.",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=4096,
        help="The maximum sequence length.",
    )
    parser.add_argument(
        "--generation_length_per_round",
        type=int,
        default=256,
        help="The generation length per conversation round.",
    )
    parser.add_argument(
        "--enable_mmca_attention",
        action='store_true',
        help="Enable the proposed multi-modal causal attention, which is similar to cross attention.",
    )
    parser.add_argument(
        "--vis_proj",
        type=str,
        default='baseline',
        help="baseline, vit, or perceiver",
    )
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    return args


def get_user_text_input():
    tmp = input(colored("Enter input (type 'quit' to exit, 'clear' to clean memory): ", 'green'))
    return tmp, tmp == "quit", tmp == "clear"

def get_user_image_input():
    tmp = input(colored("Enter image paths, separated by spaces (only one image at a time is supported for now) (type 'na' for no image): ", 'blue'))
    return tmp, not tmp == "na"

def main():
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.lm_model_name_or_path,
                                              fast_tokenizer=True)
    tokenizer.padding_side = 'right'
    model, image_processor, tokenizer = create_dsvl_model_and_transforms(
        text_tokenizer=tokenizer,
        ds_config=None,
        args=args,
    )

    model.load_state_dict(torch.load(os.path.join(args.checkpoint_path, 'pytorch_model.bin'), map_location='cpu'), strict=False)  # Z3 wouldn't save pos embeddings (vis and rope)

    model = model.eval()
    model.projection = model.projection.to('cuda')
    model.vis_encoder = model.vis_encoder.to('cuda')
    model = model.half()
    print_rank_0(model)

    num_rounds = 0
    images = []
    system_instruct = []
    TEMPLATE = DST.Prompter()  # get template
    image_num_token_list = [DST.IMAGE_NUM_1, DST.IMAGE_NUM_2, DST.IMAGE_NUM_3, DST.IMAGE_NUM_4, DST.IMAGE_NUM_5, DST.IMAGE_NUM_6, DST.IMAGE_NUM_7, DST.IMAGE_NUM_8]

    while True:
        num_rounds += 1
        while True:
            # it is easy to make a mistake here, so we need to be careful
            image_input, with_image = get_user_image_input()
            if with_image:
                try:
                    # paths are separated by spaces
                    image_paths = image_input.split(' ')
                    tmp_images = []
                    for image_path in image_paths:
                        image = Image.open(image_path).convert('RGB')
                        tmp_image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().half()
                        tmp_images.append(tmp_image_tensor)  # in case the last image path is wrong
                except Exception:
                    print(colored("Invalid image path, please try again", 'red'))
                    continue
                if len(images) + len(tmp_images) > 8:
                    print(colored("Too many images; we support at most 8 images. Please try again", 'red'))
                    continue
                images = images + tmp_images  # gather all images
                image_num = len(tmp_images)
                break
            else:
                image_num = 0
                break
        assert len(images) >= 1, "We need at least one image to begin the conversation for now."
        if len(images) > 0:
            image_tensor = torch.cat(images, dim=0)  # concatenate all images
        else:
            image_tensor = None

        text_input, quit, clear = get_user_text_input()
        if quit:
            break
        if clear:
            num_rounds = 0
            images = []
            system_instruct = []
            image_num_token_list = [DST.IMAGE_NUM_1, DST.IMAGE_NUM_2, DST.IMAGE_NUM_3, DST.IMAGE_NUM_4, DST.IMAGE_NUM_5, DST.IMAGE_NUM_6, DST.IMAGE_NUM_7, DST.IMAGE_NUM_8]
            continue

        full_prompt = TEMPLATE(text_input, with_image=with_image, first_message=(num_rounds == 1), num_images=image_num)
        if with_image:
            for i in range(image_num):
                full_prompt = re.sub(DST.DEFAULT_HUMAN_IMAGE_PRETOKEN, image_num_token_list.pop(0), full_prompt, count=1)

        full_prompt_ids = tokenizer(full_prompt).input_ids  # remove bos token

        input_ids = torch.as_tensor([system_instruct + full_prompt_ids]).cuda()  # entire history as system instruction for simplicity
        generate_output = model.generate(image_tensor, input_ids, generation_length=args.generation_length_per_round)
        extend_ids = generate_output[0].cpu().tolist()[0]
        while extend_ids[-1] == tokenizer.pad_token_id:
            extend_ids.pop()
        while extend_ids[0] == tokenizer.bos_token_id:
            extend_ids.pop(0)
        system_instruct = system_instruct + full_prompt_ids + extend_ids  # entire history as system instruction for simplicity
        system_instruct = system_instruct + [tokenizer.eos_token_id]  # add eos token

        print(f"=========== Round {num_rounds} ===========")
        print(tokenizer.decode(system_instruct))


if __name__ == "__main__":
    main()
18 changes: 18 additions & 0 deletions applications/DeepSpeed-VisualChat/chat/chat_scripts/run.sh
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
MAIN_PATH=$1

# NOTE: the two paths below are example locations; point them to your local copies of the vision encoder and the language model.
VISION_ENCODER=/blob/transformers_cache/qwen-clip
LLM=/blob/transformers_cache/Llama-2-13b-hf

export CUDA_VISIBLE_DEVICES=0 # single-GPU evaluation
# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # multi-GPU evaluation for large models (a single GPU is not enough)


python chat.py \
--lm_model_name_or_path $LLM \
--vision_model_name_or_path $VISION_ENCODER \
--checkpoint_path $MAIN_PATH --enable_mmca_attention
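
The script takes the trained checkpoint directory as its single argument and forwards it as `--checkpoint_path`. A usage sketch, assuming it is invoked from the `chat` folder so that `chat.py` resolves, with a placeholder path:

```bash
bash chat_scripts/run.sh /path/to/your/trained-checkpoint
```
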
28 changes: 28 additions & 0 deletions applications/DeepSpeed-VisualChat/eval/README.md
@@ -0,0 +1,28 @@
### ☀️ Evaluation
We provide a few examples to test the quality of the models.
To run the tests, use the `batch_generation.py` script, which reads the JSON evaluation files located at `eval_data/*.json`.
You will need to specify the path where you've saved your checkpoints. For example, if you've saved your model checkpoint at `$YOUR_CHECKPOINT_PATH/epoch-5/pytorch_model.bin`, then pass the following arguments:
```
--checkpoint_path $YOUR_CHECKPOINT_PATH --checkpoint_names epoch-5
```

##### 🏃 Run the Code
NOTE: Before you run `run_batch.sh`, please read it carefully. If you use the JSON evaluation `eval_comprehensive`, this bash script creates the folder `eval/results/eval_comprehensive` and writes to the file `eval/results/eval_comprehensive/{args.output_filename}.csv`, which has four columns; the generation output is in the last column. Please read one of our examples, such as `eval/results/eval_comprehensive/ours-set1_final.csv`.
To run the code, you need to go outside the current folder:
```
cd DeepSpeedExamples/applications/DeepSpeed-VisualChat
bash eval/run_batch.sh
```


#### 🐕 Our Model Results Overview
We present the outcomes from our three distinct models, all trained with the `qwen-clip` vision encoder and the `Llama-2-13b-hf` language model.

###### Results Directories and Training Details:
- **results/eval_single:**
This directory contains results from the model trained with LoRA, featuring a dimension size of 128.

- **results/eval_comprehensive** and **results/eval_robustness:**
These directories host results from two models:
- One model is trained excluding the Sparkles dataset (referred to as `ours-set1`).
  - The other incorporates the Sparkles dataset in training (denoted as `ours-set2`).