diff --git a/example/llm-finetune/models/baichuan2/finetune.py b/example/llm-finetune/models/baichuan2/finetune.py
index c6b9ff81ef..01f9d10353 100644
--- a/example/llm-finetune/models/baichuan2/finetune.py
+++ b/example/llm-finetune/models/baichuan2/finetune.py
@@ -19,6 +19,7 @@
     PreTrainedTokenizer,
     AutoModelForCausalLM,
 )
+from torch.utils.data import ChainDataset
 from transformers.training_args import TrainingArguments
 
 from starwhale import Dataset, finetune
@@ -86,9 +87,6 @@ def __call__(self, example: t.Dict) -> t.Dict:
     model_modules=[copilot_predict, "finetune:lora_finetune"],
 )
 def lora_finetune(train_datasets: t.List[Dataset]) -> None:
-    # TODO: support multi train datasets
-    train_dataset = train_datasets[0]
-
     model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL_DIR,
         trust_remote_code=True,
@@ -167,16 +165,22 @@ def lora_finetune(train_datasets: t.List[Dataset]) -> None:
     )
 
     # TODO: support deepspeed
+    train_dataset = ChainDataset(
+        [
+            ds.to_pytorch(
+                transform=DataCollatorForCausalLM(
+                    tokenizer=tokenizer, source_max_len=16, target_max_len=512
+                )
+            )
+            for ds in train_datasets
+        ]
+    )
 
     trainer = Trainer(
         model=model,
         tokenizer=tokenizer,
         args=train_args,
-        train_dataset=train_dataset.to_pytorch(
-            transform=DataCollatorForCausalLM(
-                tokenizer=tokenizer, source_max_len=16, target_max_len=512
-            )
-        ),
+        train_dataset=train_dataset,
     )
 
     print("Starting model training...")
diff --git a/example/llm-finetune/models/chatglm3/finetune.py b/example/llm-finetune/models/chatglm3/finetune.py
index bf68378f53..452846f4db 100644
--- a/example/llm-finetune/models/chatglm3/finetune.py
+++ b/example/llm-finetune/models/chatglm3/finetune.py
@@ -15,6 +15,7 @@
     DataCollatorForSeq2Seq,
     Seq2SeqTrainingArguments,
 )
+from torch.utils.data import ChainDataset
 from transformers.modeling_utils import unwrap_model, PreTrainedModel
 
 from starwhale import Dataset, finetune
@@ -33,9 +34,6 @@
     model_modules=["evaluation", "finetune"],
 )
 def p_tuning_v2_finetune(train_datasets: t.List[Dataset]) -> None:
-    # TODO: support multi train datasets
-    train_dataset = train_datasets[0]
-
     config = AutoConfig.from_pretrained(
         BASE_MODEL_DIR,
         trust_remote_code=True,
@@ -65,6 +63,18 @@ def p_tuning_v2_finetune(train_datasets: t.List[Dataset]) -> None:
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
 
+    train_dataset = ChainDataset(
+        [
+            ds.to_pytorch(
+                transform=MultiTurnDataTransform(
+                    tokenizer=tokenizer,
+                    max_seq_len=int(os.environ.get("MAX_SEQ_LEN", 2048)),
+                )
+            )
+            for ds in train_datasets
+        ]
+    )
+
     trainer = PrefixTrainer(
         model=model,
         tokenizer=tokenizer,
@@ -81,12 +91,7 @@ def p_tuning_v2_finetune(train_datasets: t.List[Dataset]) -> None:
             gradient_checkpointing=False,
             remove_unused_columns=False,
         ),
-        train_dataset=train_dataset.to_pytorch(
-            transform=MultiTurnDataTransform(
-                tokenizer=tokenizer,
-                max_seq_len=int(os.environ.get("MAX_SEQ_LEN", 2048)),
-            )
-        ),
+        train_dataset=train_dataset,
         data_collator=DataCollatorForSeq2Seq(
             tokenizer=tokenizer,
             model=model,
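
Note on the approach above: torch.utils.data.ChainDataset iterates its member IterableDatasets one after another, which is what lets a list of Starwhale datasets converted via Dataset.to_pytorch() be handed to the Trainer as a single train_dataset. Below is a minimal standalone sketch, assuming only that to_pytorch() returns torch IterableDataset objects; RangeDataset is a hypothetical stand-in and is not part of this patch.

import typing as t

from torch.utils.data import ChainDataset, IterableDataset


class RangeDataset(IterableDataset):
    # Hypothetical stand-in for the objects returned by Dataset.to_pytorch().
    def __init__(self, start: int, stop: int) -> None:
        self.start, self.stop = start, stop

    def __iter__(self) -> t.Iterator[int]:
        return iter(range(self.start, self.stop))


# ChainDataset yields from each wrapped dataset back to back, so the consumer
# sees one continuous stream covering every dataset in the list.
chained = ChainDataset([RangeDataset(0, 3), RangeDataset(10, 13)])
print(list(chained))  # [0, 1, 2, 10, 11, 12]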