feat: Added QLoRA fine-tuning pipeline (#40)
Showing 11 changed files with 242 additions and 38 deletions.
@@ -148,3 +148,6 @@ _html/

# mlflow output
mlruns/

# tensorflow output
results/
@@ -0,0 +1,30 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
command: >
  python -m autora.doc.pipelines.main train
  ${{inputs.new_model_name}}
  ${{inputs.data_dir}}/data.jsonl
  --base-model ${{inputs.model_path}}
code: ../src
inputs:
  data_dir:
    type: uri_folder
    path: azureml://datastores/workspaceblobstore/paths/data/autora
  model_path: autora-doc/Llama-2-7b-chat-hf-nf4
  new_model_name: autora-doc/Llama-2-7b-chat-hf-nf4-ft
environment_variables:
  PYTORCH_CUDA_ALLOC_CONF: max_split_size_mb:128
# using a curated environment doesn't work because we need additional packages
environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
  image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
  # These didn't work
  # image: mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu117-py38-torch201:biweekly.202310.3
  # image: mcr.microsoft.com/azureml/curated/acpt-pytorch-1.13-cuda11.7:latest
  # image: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.0.3-cudnn8-ubuntu18.04
  # image: mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04
  # image: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.6-cudnn8-ubuntu20.04
  # image: nvcr.io/nvidia/pytorch:23.10-py3
  conda_file: conda.yml
display_name: autodoc_train
compute: azureml:v100cluster
experiment_name: train
description: |
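A job spec like this is typically submitted with the Azure ML CLI (az ml job create --file <job>.yml) or the v2 Python SDK. The snippet below is a minimal sketch, not part of this commit; the file name train.yml and the workspace identifiers are assumed placeholders.

# Sketch: submit the command job defined above with the Azure ML v2 Python SDK.
# "train.yml" and the workspace identifiers below are assumptions/placeholders.
from azure.ai.ml import MLClient, load_job
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)

job = load_job("train.yml")  # parse the command job spec shown above
returned_job = ml_client.jobs.create_or_update(job)  # queue the run on azureml:v100cluster
print(returned_job.studio_url)  # link for monitoring the run in the Azure ML studio UI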
@@ -0,0 +1,21 @@
from typing import Iterable, List, Tuple


def load_data(data_file: str) -> Tuple[List[str], List[str]]:
    """Read (instruction, output) pairs from a JSON Lines file."""
    import jsonlines

    with jsonlines.open(data_file) as reader:
        items = [item for item in reader]
        inputs = [item["instruction"] for item in items]
        labels = [item["output"] for item in items]
        return inputs, labels


def preprocess_code(code: str) -> str:
    """Drop import statements, comments, and blank lines from a code snippet."""
    lines: Iterable[str] = code.splitlines()
    skip_starts = {"import", "from", "#"}
    lines = filter(
        lambda line: not (any([line.strip().startswith(skip) for skip in skip_starts]) or line.strip() == ""),
        lines,
    )
    return "\n".join(lines)
@@ -0,0 +1,89 @@
from typing import Dict, Iterable

import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

from autora.doc.pipelines.data import load_data, preprocess_code
from autora.doc.runtime.predict_hf import get_quantization_config
from autora.doc.runtime.prompts import INSTR_SWEETP_1, SYS_GUIDES, PromptBuilder


def get_dataset(jsonl_path: str) -> Dataset:
    # "instruction", "output"
    inputs, labels = load_data(jsonl_path)

    def gen() -> Iterable[Dict[str, str]]:
        for i, o in zip(inputs, labels):
            text = PromptBuilder(SYS_GUIDES, INSTR_SWEETP_1).add_example(preprocess_code(i), o).build()
            yield {"text": text}

    ds = Dataset.from_generator(gen)
    return ds


def fine_tune(base_model: str, new_model_name: str, dataset: Dataset) -> None:
    cuda_available = torch.cuda.is_available()
    config = {}

    # train using 4 bit quantization for lower GPU memory usage
    if cuda_available:
        config.update({"device_map": "auto", "quantization_config": get_quantization_config()})

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        **config,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    peft_params = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=8,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # All of these parameters are initial defaults and may need further tuning
    training_params = TrainingArguments(
        output_dir="./results",
        num_train_epochs=4,
        per_device_train_batch_size=1,  # TODO: Increase once there's more data
        gradient_accumulation_steps=1,
        optim="paged_adamw_32bit" if cuda_available else "adamw_torch",
        save_steps=25,
        logging_steps=1,  # TODO: Increase once there's more data
        learning_rate=2e-4,
        weight_decay=0.001,
        fp16=cuda_available,
        bf16=False,
        max_grad_norm=0.3,
        max_steps=-1,
        warmup_ratio=0.03,
        group_by_length=True,
        lr_scheduler_type="constant",
        report_to="tensorboard",
    )

    # Use a Supervised Fine-Tuning Trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_params,
        dataset_text_field="text",
        max_seq_length=1024,
        tokenizer=tokenizer,
        args=training_params,
        packing=False,
    )

    trainer.train()
    trainer.model.save_pretrained(new_model_name)
    trainer.tokenizer.save_pretrained(new_model_name)
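The train sub-command invoked by the job YAML above presumably chains these two functions; the actual CLI entry point (autora.doc.pipelines.main) is not among the files shown here. A rough sketch, with the module path assumed and the model names taken from the job inputs:

# Sketch only; the real CLI wiring lives in autora.doc.pipelines.main (not shown in this diff),
# and the module path below is an assumption.
from autora.doc.pipelines.train import get_dataset, fine_tune

dataset = get_dataset("data.jsonl")  # prompts built from instruction/output pairs
fine_tune(
    base_model="autora-doc/Llama-2-7b-chat-hf-nf4",  # nf4-quantized base model, as in the job inputs
    new_model_name="autora-doc/Llama-2-7b-chat-hf-nf4-ft",
    dataset=dataset,
)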