Fix SFT for VLM example #1865

Merged · 16 commits · Aug 2, 2024
`docs/source/sft_trainer.mdx` (124 changes: 118 additions & 6 deletions)

@@ -533,12 +533,12 @@ Note however, that the amount of performance gain is _dataset dependent_ and in

You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library, which is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks on a single A100 are listed below:

| 1 A100 40GB     | Dataset   | 🤗  | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| --------------- | --------- | --- | ---------------------- | ---------- | ------------- |
| Code Llama 34b  | Slim Orca | 1x  | 1.01x                  | **1.94x**  | -22.7%        |
| Llama-2 7b      | Slim Orca | 1x  | 0.96x                  | **1.87x**  | -39.3%        |
| Mistral 7b      | Slim Orca | 1x  | 1.17x                  | **1.88x**  | -65.9%        |
| Tiny Llama 1.1b | Alpaca    | 1x  | 1.55x                  | **2.74x**  | -57.8%        |

First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner: instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:

@@ -621,6 +621,118 @@ model = AutoModelForCausalLM.from_pretrained(

You may experience some issues with GPTQ quantization after completing training. Lowering `gradient_accumulation_steps` to `4` resolves most issues during the quantization process to GPTQ format.
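
For instance, with `SFTConfig` (a minimal sketch; the output directory is a placeholder, and any other training arguments stay as in your setup):

```python
from trl import SFTConfig

# Sketch: lower gradient accumulation ahead of post-training GPTQ quantization.
# "./sft-output" is a placeholder output directory.
training_args = SFTConfig(
    output_dir="./sft-output",
    gradient_accumulation_steps=4,
)
```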

## Extending `SFTTrainer` for Vision Language Models

`SFTTrainer` does not natively support vision-language data. However, you can adapt it by using a custom data collator that is compatible with vision-language data. This section outlines the required adjustments. For a concrete example, refer to the script [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.

### Preparing the Data

The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:

```python
images = ["obama.png"]
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Who is this?"},
            {"type": "image"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Barack Obama"}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is he famous for?"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "He is the 44th President of the United States."}
        ]
    }
]
```

To illustrate how this data format will be processed using the LLaVA model, you can use the following code:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))
```

The output will be formatted as follows:

```txt
USER: <image>
Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
```
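
To go one step further and build actual model inputs, you can pass the rendered prompt and the image through the processor together. This is a minimal sketch, assuming the `obama.png` file from the example above is available locally:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Render the conversation to a prompt string, then process it together with the image
text = processor.apply_chat_template(messages, tokenize=False)
inputs = processor(text, Image.open("obama.png"), return_tensors="pt")

print(inputs["input_ids"].shape)     # token ids, including the <image> placeholder
print(inputs["pixel_values"].shape)  # preprocessed image tensor for the vision tower
```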

<iframe src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train" frameborder="0" width="100%" height="560px"></iframe>


### A custom collator for processing multi-modal data

Unlike the default `SFTTrainer` pipeline, multi-modal data is processed on the fly during data collation. To do this, you need to define a custom collator that processes both the text and the images. The collator takes a list of examples as input (see the previous section for an example of the data format) and returns a batch of processed data. Below is an example of such a collator:

```python
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"][0] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(texts, images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    return batch
```

We can verify that the collator works as expected by running the following code:

```python
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]]  # Just two examples for illustration
collated_data = collate_fn(examples)
print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
```
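
As an additional sanity check (a sketch building on `collated_data` from above), you can confirm that padding positions are excluded from the loss:

```python
labels = collated_data["labels"]
num_masked = (labels == -100).sum().item()
print(f"{num_masked} token positions are ignored by the loss")
```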

### Training the vision-language model

Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the `SFTConfig`, specifically `dataset_text_field` and `remove_unused_columns`. We also need to set `skip_prepare_dataset` to `True` to skip the default dataset preparation. Below is an example of how to set up the `SFTTrainer`:

```python
args.dataset_text_field = ""  # needs a dummy field
args.remove_unused_columns = False
args.dataset_kwargs = {"skip_prepare_dataset": True}

trainer = SFTTrainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    tokenizer=processor.tokenizer,
)
```
```
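
From here, launching and saving the run follows the usual `Trainer` flow. A minimal sketch (assuming `args.output_dir` is set as in your configuration):

```python
trainer.train()

# Save the fine-tuned model, and the processor so inference code can rebuild inputs
trainer.save_model(args.output_dir)
processor.save_pretrained(args.output_dir)
```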

A full example of training LLaVA 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py).

- [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s)
- [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf)

## SFTTrainer

[[autodoc]] SFTTrainer
`examples/scripts/vsft_llava.py` (113 changes: 29 additions & 84 deletions)

@@ -13,55 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/vsft_llava.py \
--dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
--report_to="wandb" \
--learning_rate=1.4e-5 \
--per_device_train_batch_size=8 \
--gradient_accumulation_steps=1 \
--output_dir="data/vsft-llava-1.5-7b-hf" \
--logging_steps=5 \
--num_train_epochs=1 \
--push_to_hub \
--gradient_checkpointing \
--remove_unused_columns=False \
--torch_dtype=float16 \
--fp16=True

# peft:
pip install pillow

python examples/scripts/vsft_llava.py \
--dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
--report_to="wandb" \
--learning_rate=1.4e-5 \
--per_device_train_batch_size=8 \
--gradient_accumulation_steps=1 \
--output_dir="data/vsft-llava-1.5-7b-hf" \
--logging_steps=5 \
--num_train_epochs=1 \
--push_to_hub \
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--output_dir sft-llava-1.5-7b-hf \
--bf16 \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--remove_unused_columns=False \
--torch_dtype=float16 \
--fp16=True \
--use_peft=True \
--lora_r=64 \
--lora_alpha=16 \
--lora_target_modules=all-linear"

# evaluation:

To evaluate, first install the lmms-eval framework: pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
then run:
accelerate launch --num_processes=8 -m lmms_eval \
--model llava_hf \
--model_args pretrained=llava-hf/llava-1.5-7b-hf \
--tasks mmbench \
--batch_size 1 \
--output_path ./logs/ \
--log_sample
--use_peft \
--dataloader_num_workers 32 \
--lora_target_modules=all-linear
"""

import logging
@@ -85,7 +50,7 @@
from datasets import load_dataset

from tqdm.rich import tqdm
-from transformers import AutoTokenizer, AutoProcessor, LlavaForConditionalGeneration
+from transformers import AutoProcessor, LlavaForConditionalGeneration

from trl import (
ModelConfig,
@@ -107,6 +72,9 @@
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
sft_script_args, training_args, model_config = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+training_args.dataset_text_field = ""  # need a dummy field
+training_args.remove_unused_columns = False
+training_args.dataset_kwargs = {"skip_prepare_dataset": True}
# Force use our print callback
if TRL_USE_RICH:
    training_args.disable_tqdm = True
@@ -115,8 +83,6 @@
################
# Model, Tokenizer & Processor
################
-LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""

torch_dtype = (
    model_config.torch_dtype
    if model_config.torch_dtype in ["auto", None]
@@ -130,14 +96,9 @@
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
-tokenizer = AutoTokenizer.from_pretrained(
-    model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
-)
-tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor = AutoProcessor.from_pretrained(
    model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
)
-processor.tokenizer = tokenizer

model = LlavaForConditionalGeneration.from_pretrained(
    model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
@@ -146,34 +107,20 @@
################
# Create a data collator to encode text and image pairs
################
-class LLavaDataCollator:
-    def __init__(self, processor):
-        self.processor = processor
-
-    def __call__(self, examples):
-        texts = []
-        images = []
-        for example in examples:
-            if len(example["images"]) > 1:
-                raise ValueError("This collator only supports one image per example")
-            messages = example["messages"]
-            text = self.processor.tokenizer.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=False
-            )
-            texts.append(text)
-            images.append(example["images"][0])
-
-        batch = self.processor(texts, images, return_tensors="pt", padding=True)
-
-        labels = batch["input_ids"].clone()
-        if self.processor.tokenizer.pad_token_id is not None:
-            labels[labels == self.processor.tokenizer.pad_token_id] = -100
-        batch["labels"] = labels
-
-        return batch
-
-data_collator = LLavaDataCollator(processor)
+def collate_fn(examples):
+    # Get the texts and images, and apply the chat template
+    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
+    images = [example["images"][0] for example in examples]
+
+    # Tokenize the texts and process the images
+    batch = processor(texts, images, return_tensors="pt", padding=True)
+
+    # The labels are the input_ids, and we mask the padding tokens in the loss computation
+    labels = batch["input_ids"].clone()
+    labels[labels == processor.tokenizer.pad_token_id] = -100
+    batch["labels"] = labels
+
+    return batch

################
# Dataset
@@ -199,14 +146,12 @@ def __call__(self, examples):
trainer = SFTTrainer(
    model=model,
    args=training_args,
+    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
-    dataset_text_field="text",  # need a dummy field
-    tokenizer=tokenizer,
+    tokenizer=processor.tokenizer,
    peft_config=get_peft_config(model_config),
    callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
-    data_collator=data_collator,
-    dataset_kwargs={"skip_prepare_dataset": True},
)

trainer.train()
`trl/trainer/ppo_trainer.py` (4 changes: 2 additions & 2 deletions)

@@ -174,15 +174,15 @@ def __init__(
PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
will be preprocessed by removing the columns that are not used by the model. If none is passed,
a warning will be raised in a multi-GPU setting.
-optimizer (Optional[`torch.optim.Optimizer`]):
+optimizer (`Optional[torch.optim.Optimizer]`):
Optimizer used for training. If `None`, the `Adam` is used as default.
data_collator (Optional[function]):
Data collator function that is going to be used for `prepare_dataloader` method. Note this collator
is different from the one we use for training. Pass a valid `training_data_collator` instead.
num_shared_layers (Optional[int]):
Number of shared layers between the model and the reference model. If `None`, all layers are shared.
used only if `ref_model` is `None`.
-lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
+lr_scheduler (`Optional[torch.optim.lr_scheduler]`):
Learning rate scheduler used for training.
training_data_collator (Optional[function]):
Custom data collator used for training.
`trl/trainer/sft_trainer.py` (8 changes: 4 additions & 4 deletions)

@@ -66,16 +66,16 @@ class SFTTrainer(Trainer):
The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to
load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is
passed to the `peft_config` argument.
-args (Optional[`SFTConfig`]):
+args (`Optional[SFTConfig]`):
The arguments to tweak for training. Will default to a basic instance of [`SFTConfig`] with the `output_dir`
set to a directory named *tmp_trainer* in the current directory if not provided.
-data_collator (Optional[`transformers.DataCollator`]):
+data_collator (`Optional[transformers.DataCollator]`):
The data collator to use for training.
-train_dataset (Optional[`datasets.Dataset`]):
+train_dataset (`Optional[datasets.Dataset]`):
The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset.
eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]):
The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset.
-tokenizer (Optional[`transformers.PreTrainedTokenizer`]):
+tokenizer (`Optional[transformers.PreTrainedTokenizer]`):
The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.