example: update baichuan2/chatglm3 example to support multi train datasets (#3075)
tianweidut authored Dec 12, 2023
1 parent bca9510 commit a87d70e
Showing 2 changed files with 26 additions and 17 deletions.
20 changes: 12 additions & 8 deletions example/llm-finetune/models/baichuan2/finetune.py
@@ -19,6 +19,7 @@
     PreTrainedTokenizer,
     AutoModelForCausalLM,
 )
+from torch.utils.data import ChainDataset
 from transformers.training_args import TrainingArguments
 
 from starwhale import Dataset, finetune
@@ -86,9 +87,6 @@ def __call__(self, example: t.Dict) -> t.Dict:
     model_modules=[copilot_predict, "finetune:lora_finetune"],
 )
 def lora_finetune(train_datasets: t.List[Dataset]) -> None:
-    # TODO: support multi train datasets
-    train_dataset = train_datasets[0]
-
     model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL_DIR,
         trust_remote_code=True,
@@ -167,16 +165,22 @@ def lora_finetune(train_datasets: t.List[Dataset]) -> None:
     )
 
     # TODO: support deepspeed
+    train_dataset = ChainDataset(
+        [
+            ds.to_pytorch(
+                transform=DataCollatorForCausalLM(
+                    tokenizer=tokenizer, source_max_len=16, target_max_len=512
+                )
+            )
+            for ds in train_datasets
+        ]
+    )
 
     trainer = Trainer(
         model=model,
         tokenizer=tokenizer,
         args=train_args,
-        train_dataset=train_dataset.to_pytorch(
-            transform=DataCollatorForCausalLM(
-                tokenizer=tokenizer, source_max_len=16, target_max_len=512
-            )
-        ),
+        train_dataset=train_dataset,
     )
 
     print("Starting model training...")
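Note (editor's sketch, not part of the commit): torch.utils.data.ChainDataset simply concatenates iterable-style datasets, yielding every sample of the first dataset, then the second, and so on; that is the behavior the patch above relies on to feed several Starwhale datasets to the Hugging Face Trainer. A minimal illustration, where ToyIterable is a hypothetical stand-in for the iterable returned by ds.to_pytorch(...):

from torch.utils.data import ChainDataset, IterableDataset


class ToyIterable(IterableDataset):
    """Hypothetical stand-in for the iterable dataset that ds.to_pytorch(...) returns."""

    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):
        return iter(self.samples)


# Samples come out dataset-by-dataset, in order; ChainDataset does no cross-dataset shuffling.
chained = ChainDataset([ToyIterable([1, 2, 3]), ToyIterable([10, 11])])
print(list(chained))  # -> [1, 2, 3, 10, 11]

Because ChainDataset is itself an IterableDataset, the Trainer consumes it lazily; one trade-off of this approach is that samples are drawn sequentially per dataset rather than shuffled across the chained datasets.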
23 changes: 14 additions & 9 deletions example/llm-finetune/models/chatglm3/finetune.py
@@ -15,6 +15,7 @@
     DataCollatorForSeq2Seq,
     Seq2SeqTrainingArguments,
 )
+from torch.utils.data import ChainDataset
 from transformers.modeling_utils import unwrap_model, PreTrainedModel
 
 from starwhale import Dataset, finetune
@@ -33,9 +34,6 @@
     model_modules=["evaluation", "finetune"],
 )
 def p_tuning_v2_finetune(train_datasets: t.List[Dataset]) -> None:
-    # TODO: support multi train datasets
-    train_dataset = train_datasets[0]
-
     config = AutoConfig.from_pretrained(
         BASE_MODEL_DIR,
         trust_remote_code=True,
@@ -65,6 +63,18 @@ def p_tuning_v2_finetune(train_datasets: t.List[Dataset]) -> None:
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
 
+    train_dataset = ChainDataset(
+        [
+            ds.to_pytorch(
+                transform=MultiTurnDataTransform(
+                    tokenizer=tokenizer,
+                    max_seq_len=int(os.environ.get("MAX_SEQ_LEN", 2048)),
+                )
+            )
+            for ds in train_datasets
+        ]
+    )
+
     trainer = PrefixTrainer(
         model=model,
         tokenizer=tokenizer,
@@ -81,12 +91,7 @@ def p_tuning_v2_finetune(train_datasets: t.List[Dataset]) -> None:
             gradient_checkpointing=False,
             remove_unused_columns=False,
         ),
-        train_dataset=train_dataset.to_pytorch(
-            transform=MultiTurnDataTransform(
-                tokenizer=tokenizer,
-                max_seq_len=int(os.environ.get("MAX_SEQ_LEN", 2048)),
-            )
-        ),
+        train_dataset=train_dataset,
         data_collator=DataCollatorForSeq2Seq(
             tokenizer=tokenizer,
             model=model,
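Note (editor's sketch, not part of the commit): after this change both fine-tune entry points accept a list of Starwhale datasets rather than using only the first one. A hypothetical driver for the chatglm3 entry point, assuming the Starwhale SDK's dataset() helper for loading datasets by name; the dataset names below are placeholders, and in a real job Starwhale's runtime supplies the datasets selected when the fine-tune job is submitted:

from starwhale import dataset

from finetune import p_tuning_v2_finetune  # the entry point patched above

# Both names are made-up placeholders for real Starwhale dataset URIs.
train_datasets = [
    dataset("multi-turn-chat-a"),
    dataset("multi-turn-chat-b"),
]
p_tuning_v2_finetune(train_datasets)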
