From f5f7b67d606ee8b74bf42f090065722585b5c15b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 29 Apr 2024 13:09:34 +0200 Subject: [PATCH 1/4] FIX Issues with AdaLora initialization (#1652) Resolves #1647 - AdaLoraConfig now converts target_modules to set, same as LoRA - AdaLoraConfig now raises when used with DoRA - AdaLoraConfig now raises when used with LoftQ - AdaLoraModel now raises when trying to call add_weighted_adapter - Add tests for those in test_initialization.py - Small clean ups in test_initialization.py --- src/peft/tuners/adalora/config.py | 17 +++++++++++ src/peft/tuners/adalora/model.py | 4 +++ tests/test_initialization.py | 48 ++++++++++++++++++++++++++----- 3 files changed, 62 insertions(+), 7 deletions(-) diff --git a/src/peft/tuners/adalora/config.py b/src/peft/tuners/adalora/config.py index 93905ff28b..7972e81dd3 100644 --- a/src/peft/tuners/adalora/config.py +++ b/src/peft/tuners/adalora/config.py @@ -50,3 +50,20 @@ class AdaLoraConfig(LoraConfig): def __post_init__(self): self.peft_type = PeftType.ADALORA + + if self.use_dora: + raise ValueError(f"{self.peft_type} does not support DoRA.") + + if self.loftq_config: + raise ValueError(f"{self.peft_type} does not support LOFTQ.") + + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + # if target_modules is a regex expression, then layers_to_transform should be None + if isinstance(self.target_modules, str) and self.layers_to_transform is not None: + raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.") + + # if target_modules is a regex expression, then layers_pattern should be None + if isinstance(self.target_modules, str) and self.layers_pattern is not None: + raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") diff --git a/src/peft/tuners/adalora/model.py b/src/peft/tuners/adalora/model.py index c31b422340..2a1302979a 100644 --- a/src/peft/tuners/adalora/model.py +++ b/src/peft/tuners/adalora/model.py @@ -349,3 +349,7 @@ def update_and_allocate(self, global_step): # Pass the function and do forward propagation else: return None + + def add_weighted_adapter(self, *args, **kwargs): + """This method is not supported for AdaLoRA, use LoRA instead.""" + raise TypeError(f"{self.__class__.__name__} does not support add_weighted_adapter method.") diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 44ba12fbc7..8e0e091b56 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -19,11 +19,11 @@ from scipy import stats from torch import nn -from peft import LoraConfig, PromptTuningConfig, VeraConfig, get_peft_model +from peft import AdaLoraConfig, LoraConfig, PromptTuningConfig, VeraConfig, get_peft_model from peft.utils import infer_device -class TestInitialization: +class TestLoraInitialization: """Test class to check the initialization of adapters.""" torch_device = infer_device() @@ -253,7 +253,7 @@ def test_lora_scaling_default(self): assert model.embed.scaling["default"] == expected_scaling assert model.conv2d.scaling["default"] == expected_scaling - def test_rslora_scaling(self): + def test_lora_rslora_scaling(self): # default is True torch.manual_seed(0) @@ -296,7 +296,7 @@ def test_lora_default_scaling_pattern(self): assert model.embed.scaling["default"] == expected_scaling["embed"] assert model.conv2d.scaling["default"] == expected_scaling["conv2d"] - def test_rslora_scaling_pattern(self): + def 
test_lora_rslora_scaling_pattern(self): # default is True torch.manual_seed(0) @@ -323,7 +323,7 @@ def test_rslora_scaling_pattern(self): assert model.embed.scaling["default"] == expected_scaling["embed"] assert model.conv2d.scaling["default"] == expected_scaling["conv2d"] - def test_use_dora_linear(self, data): + def test_lora_use_dora_linear(self, data): # check that dora is a no-op when initialized torch.manual_seed(0) model = self.get_model() @@ -340,7 +340,7 @@ def test_use_dora_linear(self, data): assert torch.allclose(output_base, output_disabled) assert torch.allclose(output_base, output_dora) - def test_use_dora_linear_init_false(self, data): + def test_lora_use_dora_linear_init_false(self, data): # with init_lora_weights=False, dora should not be a no-op torch.manual_seed(0) model = self.get_model() @@ -357,11 +357,45 @@ def test_use_dora_linear_init_false(self, data): assert torch.allclose(output_base, output_disabled) assert not torch.allclose(output_base, output_dora) - def test_use_dora_with_megatron_core_raises(self): + def test_lora_use_dora_with_megatron_core_raises(self): megatron_config = {"does-not": "matter-here"} with pytest.raises(ValueError, match="DoRA does not support megatron_core"): LoraConfig(target_modules=["linear"], use_dora=True, megatron_config=megatron_config) + +class TestAdaLoraInitialization: + def test_adalora_target_modules_set(self): + config = AdaLoraConfig(target_modules=["linear", "embed", "conv2d"]) + assert config.target_modules == {"linear", "embed", "conv2d"} + + def test_adalora_use_dora_raises(self): + with pytest.raises(ValueError, match="ADALORA does not support DoRA"): + AdaLoraConfig(use_dora=True) + + def test_adalora_loftq_config_raises(self): + with pytest.raises(ValueError, match="ADALORA does not support LOFTQ"): + AdaLoraConfig(loftq_config={"loftq": "config"}) + + +class TestPromptTuningInitialization: + torch_device = infer_device() + + def get_model(self): + class MyModule(nn.Module): + def __init__(self): + super().__init__() + # choose a large weight so that averages are close to expected values + self.linear = nn.Linear(1000, 1000) + self.embed = nn.Embedding(1000, 1000) + self.conv2d = nn.Conv2d(100, 100, 3) + + def forward(self, x): + x_int = (100 * x).int() + x_4d = x.flatten().reshape(1, 100, 10, 10) + return self.linear(x), self.embed(x_int), self.conv2d(x_4d) + + return MyModule().eval().to(self.torch_device) + def test_use_prompt_tuning_init_text_raises(self): with pytest.raises(ValueError, match="When prompt_tuning_init='TEXT', tokenizer_name_or_path can't be None"): PromptTuningConfig(prompt_tuning_init="TEXT", prompt_tuning_init_text="prompt tuning init text") From 250b7eb85fa00d64d3e6f7c3829b803341c1c966 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 29 Apr 2024 13:31:23 +0200 Subject: [PATCH 2/4] FEAT Show adapter layer and model status (#1663) This PR adds a new feature to PEFT models that allows to better understand the status of adapter(s) on the model. Quoting from the doc entry that I added: Sometimes, the PEFT model can end up in a bad state, especially when handling multiple adapters. There can be some confusion around what adapters exist, which one is active, which one is merged, etc. To help investigate this issue, you can call the get_layer_status and the get_model_status methods. The first one gives you a detailed overview about the adapters for each targeted layer. The latter one gives you a high-level overview about the model status. 
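For illustration, the call pattern the new methods expose looks roughly like this (a sketch; the documentation added below walks through the full output for the same flan-t5 example):

```python
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

model = get_peft_model(AutoModel.from_pretrained("google/flan-t5-small"), LoraConfig())

# one TunerLayerStatus entry per targeted layer
layer_status = model.get_layer_status()
# one aggregated TunerModelStatus for the whole model
model_status = model.get_model_status()
```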
--------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- .../developer_guides/troubleshooting.md | 105 ++++ docs/source/package_reference/peft_model.md | 4 + src/peft/__init__.py | 2 + src/peft/mixed_model.py | 6 + src/peft/peft_model.py | 365 +++++++++++- src/peft/tuners/tuners_utils.py | 30 + tests/test_custom_models.py | 60 ++ tests/test_tuners_utils.py | 555 +++++++++++++++++- 8 files changed, 1116 insertions(+), 11 deletions(-) diff --git a/docs/source/developer_guides/troubleshooting.md b/docs/source/developer_guides/troubleshooting.md index 4cee19c48e..1780e5884b 100644 --- a/docs/source/developer_guides/troubleshooting.md +++ b/docs/source/developer_guides/troubleshooting.md @@ -135,3 +135,108 @@ model.save_adapter("my_adapter", save_embedding_layers=True) For inference, load the base model first and resize it the same way you did before you trained the model. After you've resized the base model, you can load the PEFT checkpoint. For a complete example, please check out [this notebook](https://github.com/huggingface/peft/blob/main/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb). + +### Check layer and model status + +Sometimes a PEFT model can end up in a bad state, especially when handling multiple adapters. There can be some confusion around what adapters exist, which one is active, which one is merged, etc. To help investigate this issue, call the [`~peft.PeftModel.get_layer_status`] and the [`~peft.PeftModel.get_model_status`] methods. + +The [`~peft.PeftModel.get_layer_status`] method gives you a detailed overview of each targeted layer's active, merged, and available adapters. + +```python +>>> from transformers import AutoModel +>>> from peft import get_peft_model, LoraConfig + +>>> model_id = "google/flan-t5-small" +>>> model = AutoModel.from_pretrained(model_id) +>>> model = get_peft_model(model, LoraConfig()) + +>>> model.get_layer_status() +[TunerLayerStatus(name='model.encoder.block.0.layer.0.SelfAttention.q', + module_type='lora.Linear', + enabled=True, + active_adapters=['default'], + merged_adapters=[], + requires_grad={'default': True}, + available_adapters=['default']), + TunerLayerStatus(name='model.encoder.block.0.layer.0.SelfAttention.v', + module_type='lora.Linear', + enabled=True, + active_adapters=['default'], + merged_adapters=[], + requires_grad={'default': True}, + available_adapters=['default']), +...] + +>>> model.get_model_status() +TunerModelStatus( + base_model_type='T5Model', + adapter_model_type='LoraModel', + peft_types={'default': 'LORA'}, + trainable_params=344064, + total_params=60855680, + num_adapter_layers=48, + enabled=True, + active_adapters=['default'], + merged_adapters=[], + requires_grad={'default': True}, + available_adapters=['default'], +) +``` + +In the model state output, you should look out for entries that say `"irregular"`. This means PEFT detected an inconsistent state in the model. For instance, if `merged_adapters="irregular"`, it means that for at least one adapter, it was merged on some target modules but not on others. The inference results will most likely be incorrect as a result. + +The best way to resolve this issue is to reload the whole model and adapter checkpoint(s). Ensure that you don't perform any incorrect operations on the model, e.g. manually merging adapters on some modules but not others. 
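To scan for these `"irregular"` markers programmatically, something along the following lines works (a sketch; `report_irregular` is only an illustrative helper name, not a PEFT API):

```python
from dataclasses import asdict

from peft import get_model_status


def report_irregular(model):
    # collect every model status field whose value is the string "irregular"
    status = asdict(get_model_status(model))
    requires_grad = status.pop("requires_grad")  # per-adapter dict, checked separately
    irregular = [key for key, value in status.items() if value == "irregular"]
    irregular += [f"requires_grad[{name}]" for name, value in requires_grad.items() if value == "irregular"]
    return irregular
```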
+ +Convert the layer status into a pandas `DataFrame` for an easier visual inspection. + +```python +from dataclasses import asdict +import pandas as pd + +df = pd.DataFrame(asdict(layer) for layer in model.get_layer_status()) +``` + +It is possible to get this information for non-PEFT models if they are using PEFT layers under the hood, but some information like the `base_model_type` or the `peft_types` cannot be determined in that case. As an example, you can call this on a [diffusers](https://huggingface.co/docs/diffusers/index) model like so: + +```python +>>> import torch +>>> from diffusers import StableDiffusionPipeline +>>> from peft import get_model_status, get_layer_status + +>>> path = "runwayml/stable-diffusion-v1-5" +>>> lora_id = "takuma104/lora-test-text-encoder-lora-target" +>>> pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16) +>>> pipe.load_lora_weights(lora_id, adapter_name="adapter-1") +>>> pipe.load_lora_weights(lora_id, adapter_name="adapter-2") +>>> get_layer_status(pipe.text_encoder) +[TunerLayerStatus(name='text_model.encoder.layers.0.self_attn.k_proj', + module_type='lora.Linear', + enabled=True, + active_adapters=['adapter-2'], + merged_adapters=[], + requires_grad={'adapter-1': False, 'adapter-2': True}, + available_adapters=['adapter-1', 'adapter-2']), + TunerLayerStatus(name='text_model.encoder.layers.0.self_attn.v_proj', + module_type='lora.Linear', + enabled=True, + active_adapters=['adapter-2'], + merged_adapters=[], + requires_grad={'adapter-1': False, 'adapter-2': True}, + available_adapters=['adapter-1', 'adapter-2']), +...] + +>>> get_model_status(pipe.unet) +TunerModelStatus( + base_model_type='other', + adapter_model_type='None', + peft_types={}, + trainable_params=797184, + total_params=861115332, + num_adapter_layers=128, + enabled=True, + active_adapters=['adapter-2'], + merged_adapters=[], + requires_grad={'adapter-1': False, 'adapter-2': True}, + available_adapters=['adapter-1', 'adapter-2'], +) +``` diff --git a/docs/source/package_reference/peft_model.md b/docs/source/package_reference/peft_model.md index 9313ba75e9..366ef91fd8 100644 --- a/docs/source/package_reference/peft_model.md +++ b/docs/source/package_reference/peft_model.md @@ -71,3 +71,7 @@ A `PeftModel` for mixing different adapter types (e.g. LoRA and LoHa). 
[[autodoc]] utils.get_peft_model_state_dict [[autodoc]] utils.prepare_model_for_kbit_training + +[[autodoc]] get_layer_status + +[[autodoc]] get_model_status diff --git a/src/peft/__init__.py b/src/peft/__init__.py index c2f0154419..c9a7675f7b 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -44,6 +44,8 @@ PeftModelForTokenClassification, PeftModelForQuestionAnswering, PeftModelForFeatureExtraction, + get_layer_status, + get_model_status, ) from .tuners import ( AdaptionPromptConfig, diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py index 92b9f74ecd..be640e01e7 100644 --- a/src/peft/mixed_model.py +++ b/src/peft/mixed_model.py @@ -311,6 +311,12 @@ def unload(self, *args: Any, **kwargs: Any): """ return self.base_model.unload(*args, **kwargs) + def get_layer_status(self): + raise TypeError(f"get_layer_status is not supported for {self.__class__.__name__}.") + + def get_model_status(self): + raise TypeError(f"get_model_status is not supported for {self.__class__.__name__}.") + @classmethod def _split_kwargs(cls, kwargs: dict[str, Any]): return PeftModel._split_kwargs(kwargs) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 922fee5d30..d89bd3f220 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -20,7 +20,8 @@ import warnings from contextlib import contextmanager from copy import deepcopy -from typing import Any, Optional, Union +from dataclasses import dataclass +from typing import Any, Literal, Optional, Union import packaging.version import torch @@ -54,7 +55,7 @@ PromptEncoder, VeraModel, ) -from .tuners.tuners_utils import BaseTunerLayer +from .tuners.tuners_utils import BaseTuner, BaseTunerLayer from .utils import ( SAFETENSORS_WEIGHTS_NAME, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, @@ -628,24 +629,42 @@ def disable_adapter(self): ... model(inputs) ``` """ - try: - if self.peft_config[self.active_adapter].is_prompt_learning: + if self.peft_config[self.active_adapter].is_prompt_learning: + try: # TODO: consider replacing this patching of methods with a more robust mechanism: setting a flag and # letting the underlying methods deal with it, same as how LoRA does it. old_forward = self.forward self.forward = self.base_model.forward old_prepare_inputs_for_generation = self.prepare_inputs_for_generation self.prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation - else: - self.base_model.disable_adapter_layers() - yield - finally: - if self.peft_config[self.active_adapter].is_prompt_learning: + yield + finally: self.forward = old_forward self.prepare_inputs_for_generation = old_prepare_inputs_for_generation - else: + + elif self.peft_config[self.active_adapter].is_adaption_prompt: + try: + self.base_model.disable_adapter_layers() + yield + finally: self.base_model.enable_adapter_layers() + else: # LoRA, LoHa, etc. + model_status = self.get_model_status() + if model_status.enabled == "irregular": + warnings.warn( + "The model contains some adapter layers that are enabled and others that are disabled. " + "This is most likely unintentional. After exiting the disable_adapter context, all adapters " + "will be enabled" + ) + try: + self.base_model.disable_adapter_layers() + yield + finally: + if model_status.enabled is not False: + # model_status.enabled is `True` or `"irregular"` + self.base_model.enable_adapter_layers() + def get_base_model(self) -> torch.nn.Module: """ Returns the base model. 
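The reworked `disable_adapter` branch above changes the behavior for non prompt-learning adapters when the model is in a mixed enabled/disabled state: a warning is emitted on entry and every adapter layer ends up enabled after exiting the context. A self-contained sketch of that behavior (the tiny `MLP` model is only for illustration, not part of the diff):

```python
import warnings

from torch import nn
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer


class MLP(nn.Module):  # minimal stand-in base model
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)
        self.lin1 = nn.Linear(10, 10)


model = get_peft_model(MLP(), LoraConfig(target_modules=["lin0", "lin1"]))

# create an irregular state: disable a single adapter layer
next(m for m in model.modules() if isinstance(m, BaseTunerLayer)).enable_adapters(False)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with model.disable_adapter():
        pass  # all adapters are disabled inside the context

assert any("enabled and others that are disabled" in str(w.message) for w in caught)
# on exit, all adapter layers are enabled again
assert all(not m.disable_adapters for m in model.modules() if isinstance(m, BaseTunerLayer))
```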
@@ -709,6 +728,76 @@ def set_additional_trainable_modules(self, peft_config, adapter_name): self.modules_to_save.update(peft_config.modules_to_save) _set_trainable(self, adapter_name) # this may add a new ModulesToSaveWrapper + def get_layer_status(self) -> list[TunerLayerStatus]: + """Get the status of each adapter layer in the model. + + This method returns a list of `TunerLayerStatus` dataclass instances, each of which contains the following + attributes: + + - `name` (`str`): + The name of the adapter layer, e.g. `model.encoder.block.0.layer.0.SelfAttention.q`. + - `module_type` (`str`): + The type of the adapter layer, e.g. `lora.Linear`. + - `enabled` (`bool`): + Whether the adapter layer is enabled. + - `active_adapters` (`list[str]`): + The names of the active adapters, if any, e.g. `["default"]`. + - `merged_adapters` (`list[str]`): + The names of the merged adapters, if any, e.g. `["default"]`. + - `available_adapters` (`list[str]`): + The names of the available adapters, e.g. `["default"]`. + + Args: + model ([`~PeftModel`]): + The model to get the adapter layer status from. + + Returns: + list[`peft.peft_model.TunerLayerStatus`]: + A list of dataclasses, each containing the status of the corresponding adapter layer. + + """ + return get_layer_status(self) + + def get_model_status(self) -> TunerModelStatus: + """Get the status of tuners of the model. + + This method returns a `TunerModelStatus` dataclass instance, which contains the following attributes: + + - `base_model_type` (`str`): + The type of the base model, e.g. `T5Model`. + - `adapter_model_type` (`str`): + The type of the adapter model, e.g. `LoraModel`. + - `peft_types` (`dict[str, str]`): + The mapping of adapter name to adapter type, e.g. `{"default": "LORA"}`. + - `trainable_params` (`int`): + The number of trainable parameters in the model. + - `total_params` (`int`): + The total number of parameters in the model. + - `num_adapter_layers` (`int`): + The number of adapter layers in the model. + - `enabled` (`bool`, `Literal["irregular"]`): + Whether all adapter layers are enabled. If some are enabled and some are not, this will be `"irregular"`. + This means that your model is in an inconsistent state and might not work as expected. + - `active_adapters` (`list[str]`, `Literal["irregular"]`): + The names of the active adapters. If the active adapters are not consistent across all layers, this will be + `"irregular"`, which means that your model is in an inconsistent state and might not work as expected. + - `merged_adapters` (`list[str]`, `Literal["irregular"]`): + The names of the merged adapters. If the merged adapters are not consistent across all layers, this will be + `"irregular"`, which means that your model is in an inconsistent state and might not work as expected. + - `available_adapters` (`list[str]`): + The names of the available adapters, e.g. `["default"]`. + + Args: + model ([`~PeftModel`]): + The model to get the adapter layer status from. + + Returns: + `peft.peft_model.TunerModelStatus`: + A dataclass containing the status of the model. 
+ + """ + return get_model_status(self) + @classmethod def _split_kwargs(cls, kwargs: dict[str, Any]): _kwargs_not_in_hf_hub_download_signature = ("use_auth_token",) @@ -2229,3 +2318,259 @@ def forward( prompts = prompts.to(inputs_embeds.dtype) inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + +@dataclass +class TunerLayerStatus: + name: str + module_type: str + enabled: bool + active_adapters: list[str] + merged_adapters: list[str] + requires_grad: dict[str, bool | Literal["irregular"]] + available_adapters: list[str] + + +def get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]: + """Get the status of each adapter layer in the model. + + This function returns a list of `TunerLayerStatus` dataclass instances, each of which contains the following + attributes: + + - `name` (`str`): + The name of the adapter layer, e.g. `model.encoder.block.0.layer.0.SelfAttention.q`. + - `module_type` (`str`): + The type of the adapter layer, e.g. `lora.Linear`. + - `enabled` (`bool`): + Whether the adapter layer is enabled. + - `active_adapters` (`list[str]`): + The names of the active adapters, if any, e.g. `["default"]`. + - `merged_adapters` (`list[str]`): + The names of the merged adapters, if any, e.g. `["default"]`. + - requires_grad : dict[str, bool | Literal["irregular"]] + The requires_grad status of the parameters for each adapter module. Ideally, it should be either `True` or + `False`. If the requires_grad status is not consistent across all parameters, the value will be set to + `"irregular"`. + - `available_adapters` (`list[str]`): + The names of the available adapters, e.g. `["default"]`. + + Args: + model ([Union[`~PeftModel`, `~transformers.PreTrainedModel`, `nn.Module`]]): + The model to get the adapter layer status from. + + Returns: + list[`peft.peft_model.TunerLayerStatus`]: + A list of dataclasses, each containing the status of the corresponding adapter layer. + + """ + if isinstance(model, PeftModel): + base_model = model.base_model + if not isinstance(base_model, BaseTuner): + raise TypeError( + "get_layer_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not " + "supported." 
+ ) + else: + base_model = model + + layer_status: list[TunerLayerStatus] = [] + for name, module in base_model.named_modules(): + if not isinstance(module, BaseTunerLayer): + continue + + # determine if all submodules/parameters if this module require grad or not + mapping_requires_grad_list: dict[str, list[bool]] = collections.defaultdict(list) + for adapter_module_name in module.adapter_layer_names: + adapter_module = getattr(module, adapter_module_name) + if isinstance(adapter_module, torch.nn.ModuleDict): + for key, submodule in adapter_module.items(): + for param in submodule.parameters(): + mapping_requires_grad_list[key].append(param.requires_grad) + elif isinstance(adapter_module, torch.nn.ParameterDict): + for key, param in adapter_module.items(): + mapping_requires_grad_list[key].append(param.requires_grad) + else: + # strange, we don't know how to handle this, ignore for now + pass + + def check_irrgular(vals: list[bool]) -> bool | Literal["irregular"]: + if all(vals): + return True + if not any(vals): + return False + return "irregular" + + requires_grad = {key: check_irrgular(vals) for key, vals in mapping_requires_grad_list.items()} + + status = TunerLayerStatus( + name=name, + module_type=repr(module).partition("(")[0], + enabled=not module.disable_adapters, + active_adapters=module.active_adapters, + merged_adapters=module.merged_adapters, + requires_grad=requires_grad, + available_adapters=sorted(module._get_available_adapters()), + ) + layer_status.append(status) + + if not layer_status: + raise ValueError( + "No adapter layers found in the model, please ensure that it's a PEFT model or that you have PEFT adapters " + "injected in the model." + ) + + return layer_status + + +@dataclass +class TunerModelStatus: + base_model_type: str + adapter_model_type: str + peft_types: dict[str, str] + trainable_params: int + total_params: int + num_adapter_layers: int + enabled: bool | Literal["irregular"] + active_adapters: list[str] | Literal["irregular"] + merged_adapters: list[str] | Literal["irregular"] + requires_grad: dict[str, bool | Literal["irregular"]] + available_adapters: list[str] + + +def get_model_status(model: torch.nn.Module) -> TunerModelStatus: + """Get the status of tuners of the model. + + This function returns a `TunerModelStatus` dataclass instance, which contains the following attributes: + + - `base_model_type` (`str`): + The type of the base model, e.g. `T5Model`. + - `adapter_model_type` (`str`): + The type of the adapter model, e.g. `LoraModel`. + - `peft_types` (`dict[str, str]`): + The mapping of adapter name to adapter type, e.g. `{"default": "LORA"}`. + - `trainable_params` (`int`): + The number of trainable parameters in the model. + - `total_params` (`int`): + The total number of parameters in the model. + - `num_adapter_layers` (`int`): + The number of adapter layers in the model. + - `enabled` (`bool`, `Literal["irregular"]`): + Whether all adapter layers are enabled. If some are enabled and some are not, this will be `"irregular"`. This + means that your model is in an inconsistent state and might not work as expected. + - `active_adapters` (`list[str]`, `Literal["irregular"]`): + The names of the active adapters. If the active adapters are not consistent across all layers, this will be + `"irregular"`, which means that your model is in an inconsistent state and might not work as expected. + - `merged_adapters` (`list[str]`, `Literal["irregular"]`): + The names of the merged adapters. 
If the merged adapters are not consistent across all layers, this will be + `"irregular"`, which means that your model is in an inconsistent state and might not work as expected. + - `requires_grad` (`dict[str, bool | Literal["irregular"]]`): + Whether for the given adapter, all adapter layers have `requires_grad` set to `True` or `False`. If there is a + mix, this will be set to `"irregular"`, which means that your model is in an inconsistent state and might not + work as expected. + - `available_adapters` (`list[str]`): + The names of the available adapters, e.g. `["default"]`. + + Args: + model ([Union[`~PeftModel`, `~transformers.PreTrainedModel`, `nn.Module`]]): + The model to get the adapter layer status from. + + Returns: + `peft.peft_model.TunerModelStatus`: + A dataclass containing the status of the model. + + """ + if isinstance(model, PeftModel): + if not isinstance(model.base_model, BaseTuner): + raise TypeError( + "get_model_status() got an invalid PeftModel instance; prefix tuning and adaption prompt are not " + "supported." + ) + base_model_type = model.get_base_model().__class__.__name__ + trainable_params, total_params = model.get_nb_trainable_parameters() + base_model = model.base_model + peft_types = {key: str(config.peft_type).partition(".")[-1] for key, config in base_model.peft_config.items()} + adapter_model_type = base_model.__class__.__name__ + elif isinstance(model, PreTrainedModel): + base_model_type = model.__class__.__name__ + trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model) + base_model = model + peft_types = {} + adapter_model_type = "None" + else: + base_model_type = "other" + trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model) + base_model = model + peft_types = {} + adapter_model_type = "None" + + layer_status = get_layer_status(model) + num_adapter_layers = len(layer_status) + + enabled_set: set[bool] = {status.enabled for status in layer_status} # must be {True}, {False}, or {True, False} + enabled: bool | Literal["irregular"] + if len(enabled_set) == 1: + enabled = enabled_set.pop() + else: + enabled = "irregular" + + available_adapters: list[str] = sorted(set().union(*(status.available_adapters for status in layer_status))) + + # ideally, active adapters should be consistent across all layers of the model, but we cannot guarantee it + all_active_adapters: set[tuple[str, ...]] = {tuple(status.active_adapters) for status in layer_status} + active_adapters: list[str] | Literal["irregular"] + if not all_active_adapters: + active_adapters = [] + elif len(all_active_adapters) == 1: + active_adapters = list(all_active_adapters.pop()) + else: + active_adapters = "irregular" + + # Here we determine what adapters are merged. This is not trivial because multiple adapters can be merged or not at + # the same time. Some layers may only have adapter A, some only adapter B, so it's not as easy as just checking + # which adapters are merged on each layer. + + # First, determine all adapters that are merged on at least on module. + merged_all: set[str] = set() + for status in layer_status: + merged_all.update(status.merged_adapters) + + # Next, check if on any layer, on of these adapters is not merged. 
+ merged_adapters: list[str] | Literal["irregular"] = sorted(merged_all) + for status in layer_status: + unmerged = set(status.available_adapters) - set(status.merged_adapters) + if unmerged & merged_all: + # there is overlap between unmerged adapters and adapters that should be merged + merged_adapters = "irregular" + break + + # check status of requires_grad + # first, merge the values for all layers + requires_grad_all: dict[str, list[bool | Literal["irregular"]]] = collections.defaultdict(list) + for status in layer_status: + for key, val in status.requires_grad.items(): + requires_grad_all[key].append(val) + + # then, check if the values are consistent + def check_irrgular(vals: list[bool | Literal["irregular"]]) -> bool | Literal["irregular"]: + if all(val is True for val in vals): + return True + if all(val is False for val in vals): + return False + return "irregular" + + requires_grad = {key: check_irrgular(vals) for key, vals in requires_grad_all.items()} + + adapter_model_status = TunerModelStatus( + base_model_type=base_model_type, + adapter_model_type=adapter_model_type, + peft_types=peft_types, + trainable_params=trainable_params, + total_params=total_params, + num_adapter_layers=num_adapter_layers, + enabled=enabled, + active_adapters=active_adapters, + merged_adapters=merged_adapters, + requires_grad=requires_grad, + available_adapters=available_adapters, + ) + return adapter_model_status diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 8df46a4706..f4a6ba53dc 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -281,6 +281,20 @@ def _mark_only_adapters_as_trainable(self, model: nn.Module): """ ... + @abstractmethod + def disable_adapter_layers(self) -> None: + """ + Disable all adapters in-place. + """ + ... + + @abstractmethod + def enable_adapter_layers(self) -> None: + """ + Enable all adapters in-place + """ + ... + def _check_new_adapter_config(self, config: PeftConfig) -> None: """ A helper method to check the config when a new adapter is being added. @@ -363,6 +377,10 @@ def inject_adapter(self, model: nn.Module, adapter_name: str): f"Please check the target modules and try again." ) + # It's important to set the adapter here (again), because otherwise it can happen that if a 2nd adapter is + # added, and it targets different layer(s) than the first adapter (which is active), then those different + # layers will be activated, which we don't want. 
+ self.set_adapter(self.active_adapters) self._mark_only_adapters_as_trainable(model) if self.peft_config[adapter_name].inference_mode: @@ -497,6 +515,16 @@ def active_adapter(self) -> str | list[str]: # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method return self._active_adapter + def _get_available_adapters(self) -> set[str]: + """Return all adapter names that can be found on this module.""" + adapters = set() + for layer_name in self.adapter_layer_names: + module = getattr(self, layer_name) + if not isinstance(module, (nn.ModuleDict, nn.ParameterDict)): + continue + adapters.update(set(module.keys())) + return adapters + @property def active_adapters(self): if isinstance(self.active_adapter, str): @@ -713,6 +741,8 @@ def check_adapters_to_merge(module: BaseTunerLayer, adapter_names: Optional[list """ if adapter_names is None: adapter_names = module.active_adapters + if isinstance(adapter_names, str): + raise ValueError(f"adapter_names should be a list of strings, got {adapter_names!r}.") if module.merged: merged_adapters = set(module.merged_adapters) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index d2a007a936..2744ace476 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -995,6 +995,66 @@ def test_active_adapter(self, test_name, model_id, config_cls, config_kwargs): # model.active_adapter would not work, thus we have to check the base_model directly assert model.base_model.active_adapter == ["default", "other"] + @parameterized.expand(TEST_CASES) + def test_disable_adapters_exiting_context_restores_previous_state( + self, test_name, model_id, config_cls, config_kwargs + ): + # Test that when we exit the disable_adapter context, we correctly restore the enabled state of the modules as + # they were before the context. + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + tuner_modules = [module for module in model.modules() if isinstance(module, BaseTunerLayer)] + + # all layers should be enabled + assert all(not module.disable_adapters for module in tuner_modules) + with model.disable_adapter(): + pass + # this should not change after exiting the context + assert all(not module.disable_adapters for module in tuner_modules) + + # now disable all layers + model.disable_adapter_layers() + assert all(module.disable_adapters for module in tuner_modules) + with model.disable_adapter(): + pass + assert all(module.disable_adapters for module in tuner_modules) + + @parameterized.expand(TEST_CASES) + def test_disable_adapters_exiting_context_irregular_state(self, test_name, model_id, config_cls, config_kwargs): + # When we have a model where some adapters are enabled and others are disabled, we should get a warning when + # entering the disable_adapter context because we cannot correctly restore the state of the adapters from + # before the context. After exiting the context, all adapters will be enabled, which is the status quo of how + # we deal with this. 
+ model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + tuner_modules = [module for module in model.modules() if isinstance(module, BaseTunerLayer)] + + # now we mix the states, some enabled some not + if len(tuner_modules) < 2: + # next check only works with more than 1 tuner module + return + + # disable a single layer + tuner_modules[0].enable_adapters(False) + # sanity check that we have both enabled and disabled layers + assert {module.disable_adapters for module in tuner_modules} == {True, False} + # check that we get a warning with irregular states + msg = "The model contains some adapter layers that are enabled and others that are disabled" + with self.assertWarnsRegex(UserWarning, expected_regex=msg): + with model.disable_adapter(): + pass + + # when encountering irregular adapters, we enable all adapters at the end of the context + assert all(not module.disable_adapters for module in tuner_modules) + @parameterized.expand(TEST_CASES) def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_delete_adapter(model_id, config_cls, config_kwargs) diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index c6ce7bdc19..1388fa0694 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -14,17 +14,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re import unittest from copy import deepcopy import pytest +import torch from diffusers import StableDiffusionPipeline from parameterized import parameterized from torch import nn from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig -from peft import IA3Config, LoHaConfig, LoraConfig, get_peft_model +from peft import ( + AdaptionPromptConfig, + IA3Config, + LoHaConfig, + LoraConfig, + PromptTuningConfig, + get_layer_status, + get_model_status, + get_peft_model, +) from peft.tuners.tuners_utils import ( + BaseTunerLayer, _maybe_include_all_linear_layers, check_target_module_exists, inspect_matched_modules, @@ -372,3 +384,544 @@ def test_realistic_example(self): f"transformer.h.{i}.self_attention.query_key_value" for i in range(len(model.base_model.transformer.h)) ] assert model.targeted_module_names == expected + + +class TestModelAndLayerStatus: + """Check the methods `get_layer_status` and `get_model_status`.` + + Note that we only test LoRA here but the same logic should work for other tuner types (if they support the + corresponding features like merging). 
+ + """ + + @pytest.fixture + def small_model(self): + class SmallModel(nn.Module): + def __init__(self): + super().__init__() + self.lin0 = nn.Linear(10, 10) + self.lin1 = nn.Linear(10, 10) + + config = LoraConfig(target_modules="lin0") + return get_peft_model(SmallModel(), config) + + @pytest.fixture + def large_model(self): + class LargeModel(nn.Module): + def __init__(self): + super().__init__() + self.lin0 = nn.Linear(10, 10) + self.conv0 = nn.Conv2d(3, 10, 3) + self.emb0 = nn.Embedding(10, 10) + self.lin1 = nn.Linear(10, 10) + self.conv1 = nn.Conv2d(3, 10, 3) + self.emb1 = nn.Embedding(10, 10) + + config0 = LoraConfig(target_modules=["lin0", "conv1", "emb0"]) + config1 = LoraConfig(target_modules=["lin0", "lin1"], r=16) + model = get_peft_model(LargeModel(), config0) + model.add_adapter("other", config1) + return model + + ################ + # layer status # + ################ + + def test_layer_names_small(self, small_model): + layer_status = small_model.get_layer_status() + expected = ["model.lin0"] + assert [status.name for status in layer_status] == expected + + def test_layer_names_large(self, large_model): + layer_status = large_model.get_layer_status() + result = sorted([status.name for status in layer_status]) + expected = ["model.conv1", "model.emb0", "model.lin0", "model.lin1"] + assert result == expected + + def test_module_type_small(self, small_model): + layer_status = small_model.get_layer_status() + assert [status.module_type for status in layer_status] == ["lora.Linear"] + + def test_module_type_large(self, large_model): + layer_status = large_model.get_layer_status() + result = sorted([status.module_type for status in layer_status]) + expected = ["lora.Conv2d", "lora.Embedding", "lora.Linear", "lora.Linear"] + assert result == expected + + def test_enabled_small(self, small_model): + layer_status = small_model.get_layer_status() + assert [status.enabled for status in layer_status] == [True] + + def test_enabled_large(self, large_model): + layer_status = large_model.get_layer_status() + result = [status.enabled for status in layer_status] + expected = [True, True, True, True] + assert result == expected + + def test_enabled_irregular(self, large_model): + # this is an invalid state, but we should still test it + # disable a single layer + for module in large_model.modules(): + if isinstance(module, BaseTunerLayer): + module.enable_adapters(False) + break + + layer_status = large_model.get_layer_status() + result = [status.enabled for status in layer_status] + expected = [False, True, True, True] + assert result == expected + + def test_active_adapters_small(self, small_model): + layer_status = small_model.get_layer_status() + assert [status.active_adapters for status in layer_status] == [["default"]] + + def test_active_adapters_large(self, large_model): + layer_status = large_model.get_layer_status() + result = [status.active_adapters for status in layer_status] + # note: as currently implemented, the active adapter can be an adapter that does not exist on this specific + # layer, for instance, layer 3 (i.e. 
index 2) only has the "other" adapter but "default" is still shown as the + # active adapter + expected = [["default"], ["default"], ["default"], ["default"]] + assert result == expected + + # switch to "other" + large_model.set_adapter("other") + layer_status = large_model.get_layer_status() + result = [status.active_adapters for status in layer_status] + expected = [["other"], ["other"], ["other"], ["other"]] + + def test_merge_adapters_small(self, small_model): + layer_status = small_model.get_layer_status() + assert [status.merged_adapters for status in layer_status] == [[]] + assert [status.available_adapters for status in layer_status] == [["default"]] + + # now merge "default" + small_model.merge_adapter(["default"]) + layer_status = small_model.get_layer_status() + assert [status.merged_adapters for status in layer_status] == [["default"]] + assert [status.available_adapters for status in layer_status] == [["default"]] + + def test_merge_adapters_large(self, large_model): + layer_status = large_model.get_layer_status() + result = [status.merged_adapters for status in layer_status] + assert result == [[], [], [], []] + + # now merge "default" + large_model.merge_adapter(["default"]) + layer_status = large_model.get_layer_status() + result = [status.merged_adapters for status in layer_status] + # default is on layer 0, 1, and 3 + assert result == [["default"], ["default"], [], ["default"]] + + # now merge "other" + large_model.unmerge_adapter() + large_model.merge_adapter(["other"]) + layer_status = large_model.get_layer_status() + result = [status.merged_adapters for status in layer_status] + # other is on layer 0 and 2 + assert result == [["other"], [], ["other"], []] + + # now merge both + large_model.merge_adapter(["default", "other"]) + layer_status = large_model.get_layer_status() + result = [status.merged_adapters for status in layer_status] + # default is on layer 0, 1, and 3, other is on layer 0 and 2 + assert result == [["other", "default"], ["default"], ["other"], ["default"]] + + def test_requires_grad_small(self, small_model): + layer_status = small_model.get_layer_status() + assert [status.requires_grad for status in layer_status] == [{"default": True}] + + def test_requires_grad_large(self, large_model): + layer_status = large_model.get_layer_status() + result = [status.requires_grad for status in layer_status] + # default is on layer 0, 1, and 3, other is on layer 0 and 2 + expected = [{"default": True, "other": False}, {"default": True}, {"other": False}, {"default": True}] + assert result == expected + + # now activate "other" + large_model.set_adapter("other") + layer_status = large_model.get_layer_status() + result = [status.requires_grad for status in layer_status] + expected = [{"default": False, "other": True}, {"default": False}, {"other": True}, {"default": False}] + assert result == expected + + def test_requires_grad_irregular(self, large_model): + # inject an embedding layer with requires_grad=False + # this is an invalid state, but we should still test it + lora_embedding_A = nn.Parameter(torch.zeros(10, 10)) + lora_embedding_B = nn.Parameter(torch.zeros(10, 10)) + lora_embedding_A.requires_grad = False + lora_embedding_B.requires_grad = False + large_model.base_model.model.lin0.lora_embedding_A["default"] = lora_embedding_A + large_model.base_model.model.lin0.lora_embedding_B["default"] = lora_embedding_B + + layer_status = large_model.get_layer_status() + result = [status.requires_grad for status in layer_status] + expected = [{"default": "irregular", 
"other": False}, {"default": True}, {"other": False}, {"default": True}] + assert result == expected + + def test_available_adapters_small(self, small_model): + layer_status = small_model.get_layer_status() + result = [status.available_adapters for status in layer_status] + expected = [["default"]] + assert result == expected + + def test_available_adapters_large(self, large_model): + layer_status = large_model.get_layer_status() + result = [status.available_adapters for status in layer_status] + expected = [["default", "other"], ["default"], ["other"], ["default"]] + assert result == expected + + ################ + # model status # + ################ + + def test_base_model_type_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.base_model_type == "SmallModel" + + def test_base_model_type_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.base_model_type == "LargeModel" + + def test_base_model_type_transformers_automodel(self): + # ensure that this also works with transformers AutoModels + model_id = "google/flan-t5-small" + model = AutoModel.from_pretrained(model_id) + model = get_peft_model(model, LoraConfig()) + model_status = model.get_model_status() + assert model_status.base_model_type == "T5Model" + + def test_adapter_model_type_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.adapter_model_type == "LoraModel" + + def test_adapter_model_type_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.adapter_model_type == "LoraModel" + + def test_peft_types_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.peft_types == {"default": "LORA"} + + def test_peft_types_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.peft_types == {"default": "LORA", "other": "LORA"} + + def test_nb_params_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.trainable_params == 160 + assert model_status.total_params == 380 + + def test_nb_params_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.trainable_params == 616 + assert model_status.total_params == 2236 + + def test_num_adapter_layers_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.num_adapter_layers == 1 + + def test_num_adapter_layers_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.num_adapter_layers == 4 + + def test_model_enabled_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.enabled is True + + def test_model_enabled_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.enabled is True + + def test_model_disabled_small(self, small_model): + small_model.disable_adapter_layers() + model_status = small_model.get_model_status() + assert model_status.enabled is False + + def test_model_disabled_large(self, large_model): + large_model.disable_adapter_layers() + model_status = large_model.get_model_status() + assert model_status.enabled is False + + def test_model_enabled_irregular(self, large_model): + # this is an invalid state, but we should still test it + # disable a single layer + for module in large_model.modules(): + if isinstance(module, BaseTunerLayer): + module.enable_adapters(False) + break + + 
model_status = large_model.get_model_status() + assert model_status.enabled == "irregular" + + def test_model_active_adapters_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.active_adapters == ["default"] + + def test_model_active_adapters_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.active_adapters == ["default"] + + large_model.set_adapter("other") + model_status = large_model.get_model_status() + assert model_status.active_adapters == ["other"] + + def test_model_active_adapters_irregular(self, large_model): + # this is an invalid state, but we should still test it + # disable a single layer + for module in large_model.modules(): + if isinstance(module, BaseTunerLayer): + # switch a single layer's active adapter from default to other + if module.active_adapters == ["default"]: + module._active_adapter = "other" + assert module.active_adapters == ["other"] + break + + model_status = large_model.get_model_status() + assert model_status.active_adapters == "irregular" + + def test_model_merged_adapters_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.merged_adapters == [] + + small_model.merge_adapter() + model_status = small_model.get_model_status() + assert model_status.merged_adapters == ["default"] + + small_model.unmerge_adapter() + model_status = small_model.get_model_status() + assert model_status.merged_adapters == [] + + def test_model_merged_adapters_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.merged_adapters == [] + + large_model.merge_adapter(["default"]) + model_status = large_model.get_model_status() + assert model_status.merged_adapters == ["default"] + + large_model.unmerge_adapter() + large_model.merge_adapter(["other"]) + model_status = large_model.get_model_status() + assert model_status.merged_adapters == ["other"] + + large_model.unmerge_adapter() + large_model.merge_adapter(["default", "other"]) + model_status = large_model.get_model_status() + assert model_status.merged_adapters == ["default", "other"] + + def test_model_merged_adapters_irregular(self, large_model): + # this is an invalid state, but we should still test it + # by merging only lin0 of "default", we end up in a irregular state, because not all "default" layers are merged + large_model.base_model.lin0.merge(["default"]) + + model_status = large_model.get_model_status() + assert model_status.merged_adapters == "irregular" + + def test_model_requires_grad_model_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.requires_grad == {"default": True} + + def test_model_requires_grad_model_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.requires_grad == {"default": True, "other": False} + + large_model.set_adapter("other") + model_status = large_model.get_model_status() + assert model_status.requires_grad == {"default": False, "other": True} + + def test_model_requires_grad_model_irregular(self, large_model): + # inject an embedding layer with requires_grad=False + # this is an invalid state, but we should still test it + lora_embedding_A = nn.Parameter(torch.zeros(10, 10)) + lora_embedding_B = nn.Parameter(torch.zeros(10, 10)) + lora_embedding_A.requires_grad = False + lora_embedding_B.requires_grad = False + large_model.base_model.model.lin0.lora_embedding_A["default"] = lora_embedding_A + 
large_model.base_model.model.lin0.lora_embedding_B["default"] = lora_embedding_B + + model_status = large_model.get_model_status() + assert model_status.requires_grad == {"default": "irregular", "other": False} + + def test_model_available_adapters_small(self, small_model): + model_status = small_model.get_model_status() + assert model_status.available_adapters == ["default"] + + def test_model_available_adapters_large(self, large_model): + model_status = large_model.get_model_status() + assert model_status.available_adapters == ["default", "other"] + + def test_loha_model(self): + # ensure that this also works with non-LoRA, it's enough to test one other tuner + class SmallModel(nn.Module): + def __init__(self): + super().__init__() + self.lin0 = nn.Linear(10, 10) + self.lin1 = nn.Linear(10, 10) + + base_model = SmallModel() + config = LoHaConfig(target_modules=["lin0", "lin1"], init_weights=False) + model = get_peft_model(base_model, config) + + model_status = get_model_status(model) + layer_status = get_layer_status(model) + + assert model_status.base_model_type == "SmallModel" + assert model_status.adapter_model_type == "LoHaModel" + assert model_status.peft_types == {"default": "LOHA"} + assert model_status.trainable_params == 640 + assert model_status.total_params == 860 + assert model_status.num_adapter_layers == 2 + assert model_status.enabled is True + assert model_status.active_adapters == ["default"] + assert model_status.merged_adapters == [] + assert model_status.requires_grad == {"default": True} + assert model_status.available_adapters == ["default"] + + layer_status0 = layer_status[0] + assert len(layer_status) == 2 + assert layer_status0.name == "model.lin0" + assert layer_status0.module_type == "loha.Linear" + assert layer_status0.enabled is True + assert layer_status0.active_adapters == ["default"] + assert layer_status0.merged_adapters == [] + assert layer_status0.requires_grad == {"default": True} + assert layer_status0.available_adapters == ["default"] + + ################### + # non-PEFT models # + ################### + + def test_transformers_model(self): + model_id = "peft-internal-testing/gpt2-lora-random" + # note that loading through AutoModelForCausalLM.from_pretrained does not enable training mode, hence + # requires_grad=False + model = AutoModelForCausalLM.from_pretrained(model_id) + model_status = get_model_status(model) + layer_status = get_layer_status(model) + + assert model_status.base_model_type == "GPT2LMHeadModel" + assert model_status.adapter_model_type == "None" + assert model_status.peft_types == {} + assert model_status.trainable_params == 0 + assert model_status.total_params == 124734720 + assert model_status.num_adapter_layers == 12 + assert model_status.enabled is True + assert model_status.active_adapters == ["default"] + assert model_status.merged_adapters == [] + assert model_status.requires_grad == {"default": False} + assert model_status.available_adapters == ["default"] + + layer_status0 = layer_status[0] + assert len(layer_status) == 12 + assert layer_status0.name == "transformer.h.0.attn.c_attn" + assert layer_status0.module_type == "lora.Linear" + assert layer_status0.enabled is True + assert layer_status0.active_adapters == ["default"] + assert layer_status0.merged_adapters == [] + assert layer_status0.requires_grad == {"default": False} + assert layer_status0.available_adapters == ["default"] + + def test_model_with_injected_layers(self, large_model): + model = large_model.base_model.model + model_status = get_model_status(model) + 
layer_status = get_layer_status(model) + + assert model_status.base_model_type == "other" + assert model_status.adapter_model_type == "None" + assert model_status.peft_types == {} + assert model_status.trainable_params == 616 + assert model_status.total_params == 2236 + assert model_status.num_adapter_layers == 4 + assert model_status.enabled is True + assert model_status.active_adapters == ["default"] + assert model_status.merged_adapters == [] + assert model_status.requires_grad == {"default": True, "other": False} + assert model_status.available_adapters == ["default", "other"] + + layer_status1 = layer_status[1] + assert len(layer_status) == 4 + assert layer_status1.name == "emb0" + assert layer_status1.module_type == "lora.Embedding" + assert layer_status1.enabled is True + assert layer_status1.active_adapters == ["default"] + assert layer_status1.merged_adapters == [] + assert layer_status1.requires_grad == {"default": True} + assert layer_status1.available_adapters == ["default"] + + ############### + # error cases # + ############### + + def test_vanilla_model_raises(self): + model = nn.Linear(10, 10) + # note: full error message is longer + with pytest.raises(ValueError, match="No adapter layers found in the model"): + get_layer_status(model) + + with pytest.raises(ValueError, match="No adapter layers found in the model"): + get_model_status(model) + + def test_transformer_model_without_adapter_raises(self): + model = AutoModelForCausalLM.from_pretrained("gpt2") + # note: full error message is longer + with pytest.raises(ValueError, match="No adapter layers found in the model"): + get_layer_status(model) + + with pytest.raises(ValueError, match="No adapter layers found in the model"): + get_model_status(model) + + def test_prefix_tuning(self): + model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") + config = PromptTuningConfig(task_type="SEQ_2_SEQ_LM", num_virtual_tokens=10) + model = get_peft_model(model, config) + + # note: full error message is longer + with pytest.raises(TypeError, match=re.escape("get_layer_status() got an invalid PeftModel instance")): + model.get_layer_status() + + with pytest.raises(TypeError, match=re.escape("get_model_status() got an invalid PeftModel instance")): + model.get_model_status() + + def test_adaption_prompt(self): + model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/tiny-random-LlamaForCausalLM") + config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4) + model = get_peft_model(model, config) + + # note: full error message is longer + with pytest.raises(TypeError, match=re.escape("get_layer_status() got an invalid PeftModel instance")): + model.get_layer_status() + + with pytest.raises(TypeError, match=re.escape("get_model_status() got an invalid PeftModel instance")): + model.get_model_status() + + def test_mixed_model_raises(self): + class SimpleNet(nn.Module): + def __init__(self, bias=True): + super().__init__() + # note: out_features must be > rank or else OFT will be an identity transform + self.lin0 = nn.Linear(10, 20, bias=bias) + self.relu = nn.ReLU() + self.lin1 = nn.Linear(20, 16, bias=bias) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + return X + + base_model = SimpleNet() + config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False) + config1 = LoHaConfig(target_modules=["lin0", "lin1"], init_weights=False) + model = get_peft_model(base_model, config0, adapter_name="adapter0", mixed="mixed") + 
model.add_adapter("adapter1", config1) + + # note: full error message is longer + with pytest.raises(TypeError, match="get_layer_status is not supported for PeftMixedModel"): + model.get_layer_status() + + with pytest.raises(TypeError, match="get_model_status is not supported for PeftMixedModel"): + model.get_model_status() From e19f7bf424bae38263a3ec73e0e6ea6e7590c153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ji=C5=99=C3=AD=20Podiv=C3=ADn?= <66251151+jpodivin@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:23:46 +0200 Subject: [PATCH 3/4] FIX Doc error prompt tuning seq len calc (#1686) Signed-off-by: Jiri Podivin --- docs/source/task_guides/prompt_based_methods.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/task_guides/prompt_based_methods.md b/docs/source/task_guides/prompt_based_methods.md index d93edfacb4..c5e61e28c6 100644 --- a/docs/source/task_guides/prompt_based_methods.md +++ b/docs/source/task_guides/prompt_based_methods.md @@ -90,7 +90,7 @@ def preprocess_function(examples, text_column="Tweet text", label_column="text_l model_inputs["attention_mask"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[ "attention_mask" ][i] - labels["input_ids"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids + labels["input_ids"][i] = [-100] * (max_length - len(label_input_ids)) + label_input_ids model_inputs["input_ids"][i] = torch.tensor(model_inputs["input_ids"][i][:max_length]) model_inputs["attention_mask"][i] = torch.tensor(model_inputs["attention_mask"][i][:max_length]) labels["input_ids"][i] = torch.tensor(labels["input_ids"][i][:max_length]) From 608a90ded9985ee1c5912d738082bb1fd618902b Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 29 Apr 2024 18:27:13 +0200 Subject: [PATCH 4/4] TST: Skiping AWQ tests for now .. (#1690) * Update test_gpu_examples.py * Update tests/test_gpu_examples.py --- tests/test_gpu_examples.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 91b09a0fbd..f2c2ae63d7 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1930,8 +1930,10 @@ def test_causal_lm_training_aqlm(self): assert trainer.state.log_history[-1]["train_loss"] is not None +# TODO: unskip the tests once https://github.com/casper-hansen/AutoAWQ/issues/466 is fixed @require_torch_gpu @require_auto_awq +@pytest.mark.skip(reason="Needs https://github.com/casper-hansen/AutoAWQ/issues/466 to be fixed first") class PeftAwqGPUTests(unittest.TestCase): r""" Awq + peft tests