To run this code, several libraries need to be installed:
- trl (for SFTTrainer)
- transformers (for models and tokenizers)
- accelerate (for training acceleration)
- peft (for efficient fine-tuning)
- datasets (for loading datasets)
- bitsandbytes (for model quantization)
- einops (required for loading Falcon models)
- wandb (for experiment tracking)
You can install these libraries using pip:
pip install trl transformers accelerate datasets bitsandbytes einops wandb
pip install git+https://github.com/huggingface/peft.git
The code uses the Guanaco dataset, a clean subset of the OpenAssistant dataset adapted to train general-purpose chatbots. The dataset is available on the Hugging Face Hub.
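As a rough sketch, loading the dataset might look like this (the dataset identifier timdettmers/openassistant-guanaco is an assumption based on the commonly used Guanaco subset; adjust it if the original code references a different one):

```python
from datasets import load_dataset

# Assumed identifier for the Guanaco subset of OpenAssistant;
# replace with the exact dataset used in the original code if it differs.
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
print(dataset[0]["text"])  # each example is a formatted chat transcript
```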
The model being fine-tuned is Falcon 7B. It is quantized to 4-bit to save memory, and LoRA adapters are then attached on top of the quantized model.
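A minimal sketch of the quantized loading and LoRA setup, assuming the tiiuae/falcon-7b checkpoint; the LoRA hyperparameters (r, alpha, dropout, target modules) are illustrative values, not necessarily the ones used in the original code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig

model_name = "tiiuae/falcon-7b"

# Load the base model in 4-bit (NF4) to reduce memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,  # Falcon ships custom model code
)
model.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapter configuration (illustrative values)
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["query_key_value"],  # Falcon's fused attention projection
)
```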
The code uses SFTTrainer from the TRL library, which provides a wrapper around the transformers Trainer to easily fine-tune models on instruction-based datasets using PEFT adapters.
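Wiring the pieces together might look roughly like this; argument names such as dataset_text_field and max_seq_length follow older TRL releases (newer versions move some of them into SFTConfig), and training_arguments is the transformers.TrainingArguments object built from the parameters listed in the next section:

```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,     # SFTTrainer wraps the base model with the LoRA adapters
    dataset_text_field="text",   # column containing the formatted conversations
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,     # defined in the sketch further below
)
```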
The model is trained with the following parameters (see the sketch after this list):
- Output directory: ./results
- Per-device training batch size: 4
- Gradient accumulation steps: 4
- Optimizer: paged_adamw_32bit
- Save steps: 10
- Logging steps: 10
- Learning rate: 2e-4
- Max grad norm: 0.3
- Max steps: 500
- Warmup ratio: 0.03
- LR scheduler type: constant
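These parameters map onto transformers.TrainingArguments roughly as follows (a sketch; any other fields from the original code, such as precision or reporting options, are omitted):

```python
from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    optim="paged_adamw_32bit",
    save_steps=10,
    logging_steps=10,
    learning_rate=2e-4,
    max_grad_norm=0.3,
    max_steps=500,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
)
```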
After training, the model should converge nicely. The SFTTrainer also takes care of properly saving only the adapters during training instead of saving the entire model.
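Because the trainer holds a PEFT-wrapped model, saving a checkpoint writes only the small adapter weights rather than the full 7B-parameter model. A sketch (the output path is just an example):

```python
trainer.train()

# With a PEFT model, save_pretrained stores only the LoRA adapter weights,
# which are a small fraction of the full model size.
trainer.model.save_pretrained("./falcon-7b-guanaco-adapter")  # example path
```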
To use the code, run it in a Google Colab environment after installing the required libraries and loading the Guanaco dataset; it handles the rest, from loading the model and fine-tuning it to saving the results. Note that the code was written for Google Colab and may not work as expected in a different environment.
Also, remember that fine-tuning large models like Falcon 7B requires significant computational resources. Please ensure your environment has sufficient resources to avoid any issues during the fine-tuning process.