diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 4119e547a37616..880de41a75d785 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -257,6 +257,23 @@ def _is_peft_model(model):
     return False
 
 
+def _has_peft_submodule(model):
+    if _is_peft_model(model):
+        return True
+    elif is_peft_available():
+        classes_to_check = (PeftModel,) if is_peft_available() else ()
+        # Here we also check whether any submodule is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
+        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
+            from peft import PeftMixedModel
+
+            classes_to_check = (*classes_to_check, PeftMixedModel)
+
+        for submodule in model.modules():
+            if isinstance(submodule, classes_to_check):
+                return True
+    return False
+
+
 def _get_fsdp_ckpt_kwargs():
     # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
     if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
@@ -477,7 +494,7 @@ def __init__(
             )
 
         # At this stage the model is already loaded
-        if _is_quantized_and_base_model and not _is_peft_model(model):
+        if _is_quantized_and_base_model and not _has_peft_submodule(model):
             raise ValueError(
                 "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 26fa4624674ec5..2fae1f624e18e9 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -27,7 +27,7 @@
 from functools import partial
 from itertools import product
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union
 from unittest.mock import Mock, patch
 
 import numpy as np
@@ -117,6 +117,7 @@
     GlueDataTrainingArguments,
     GPT2Config,
     GPT2LMHeadModel,
+    Idefics2ForConditionalGeneration,
     LineByLineTextDataset,
     LlamaConfig,
     LlamaForCausalLM,
@@ -996,6 +997,49 @@ def test_bnb_compile(self):
         with self.assertRaises(ValueError):
             _ = Trainer(tiny_model, args, train_dataset=train_dataset)  # noqa
 
+    @require_peft
+    @require_bitsandbytes
+    def test_peft_submodule(self):
+        from peft import LoraConfig, PeftModel, get_peft_model
+
+        # Due to the way the Trainer is implemented, the model must be saveable with 'save_pretrained'.
+        # Therefore, to use PEFT submodules you must provide your own 'save_pretrained' method.
+        # This example subclass saves every submodule that is an instance of `PeftModel`.
+        class PeftSubmoduleIdefics2ForConditionalGeneration(Idefics2ForConditionalGeneration):
+            def save_pretrained(
+                self,
+                save_directory: Union[str, os.PathLike],
+                **kwargs,
+            ):
+                for name, submodule in self.named_modules():
+                    if isinstance(submodule, PeftModel):
+                        submodule.save_pretrained(os.path.join(save_directory, name), **kwargs)
+
+        # Simply tests whether initializing a Trainer with a PEFT adapter on a submodule passes the `_has_peft_submodule` check.
+        # The model should be recognized as trainable via the PEFT submodule and not raise an error when quantized.
+        tiny_model = PeftSubmoduleIdefics2ForConditionalGeneration.from_pretrained(
+            "trl-internal-testing/tiny-random-idefics2", load_in_4bit=True
+        )
+
+        peft_config = LoraConfig(
+            r=8,
+            lora_alpha=32,
+            target_modules=["q_proj", "k_proj", "v_proj"],
+            lora_dropout=0.05,
+            bias="none",
+        )
+        tiny_model.model.vision_model = get_peft_model(tiny_model.model.vision_model, peft_config)
+
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = TrainingArguments(tmp_dir, learning_rate=1e-9, logging_steps=5)
+            trainer = Trainer(tiny_model, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
     @require_peft
     def test_multiple_peft_adapters(self):
         from peft import LoraConfig, get_peft_model