
Commit 4c13b0e

Merge remote-tracking branch 'upstream/main' into main

DTennant committed Apr 30, 2024
2 parents 8b40d23 + 608a90d
Showing 13 changed files with 1,181 additions and 19 deletions.
105 changes: 105 additions & 0 deletions docs/source/developer_guides/troubleshooting.md
@@ -135,3 +135,108 @@ model.save_adapter("my_adapter", save_embedding_layers=True)
For inference, load the base model first and resize it the same way you did before you trained the model. After you've resized the base model, you can load the PEFT checkpoint.
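As a rough sketch of that order of operations (the model id, the adapter path, and the assumption that the extended tokenizer was saved alongside the adapter are all illustrative, not fixed API requirements):

```python
# a minimal sketch, assuming the extended tokenizer was saved at "path/to/adapter";
# "base-model-id" and "path/to/adapter" are placeholder names
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("path/to/adapter")
base = AutoModelForCausalLM.from_pretrained("base-model-id")
base.resize_token_embeddings(len(tokenizer))  # same resize as before training
model = PeftModel.from_pretrained(base, "path/to/adapter")
```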

For a complete example, please check out [this notebook](https://github.com/huggingface/peft/blob/main/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb).

### Check layer and model status

Sometimes a PEFT model can end up in a bad state, especially when it handles multiple adapters: it may become unclear which adapters exist, which one is active, and which one is merged. To investigate, call the [`~peft.PeftModel.get_layer_status`] and [`~peft.PeftModel.get_model_status`] methods.

The [`~peft.PeftModel.get_layer_status`] method gives you a detailed overview of each targeted layer's active, merged, and available adapters.

```python
>>> from transformers import AutoModel
>>> from peft import get_peft_model, LoraConfig

>>> model_id = "google/flan-t5-small"
>>> model = AutoModel.from_pretrained(model_id)
>>> model = get_peft_model(model, LoraConfig())

>>> model.get_layer_status()
[TunerLayerStatus(name='model.encoder.block.0.layer.0.SelfAttention.q',
module_type='lora.Linear',
enabled=True,
active_adapters=['default'],
merged_adapters=[],
requires_grad={'default': True},
available_adapters=['default']),
TunerLayerStatus(name='model.encoder.block.0.layer.0.SelfAttention.v',
module_type='lora.Linear',
enabled=True,
active_adapters=['default'],
merged_adapters=[],
requires_grad={'default': True},
available_adapters=['default']),
...]

>>> model.get_model_status()
TunerModelStatus(
base_model_type='T5Model',
adapter_model_type='LoraModel',
peft_types={'default': 'LORA'},
trainable_params=344064,
total_params=60855680,
num_adapter_layers=48,
enabled=True,
active_adapters=['default'],
merged_adapters=[],
requires_grad={'default': True},
available_adapters=['default'],
)
```

In the model status output, look out for entries that say `"irregular"`, which means PEFT detected an inconsistent state in the model. For instance, `merged_adapters="irregular"` means that at least one adapter was merged on some target modules but not on others, so inference results will most likely be incorrect.

The best way to resolve this issue is to reload the whole model and adapter checkpoint(s). Ensure that you don't perform any incorrect operations on the model, e.g. manually merging adapters on some modules but not others.
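To check for this state programmatically instead of by eye, a small sketch like the following works, assuming `TunerModelStatus` is a plain dataclass (only top-level fields are inspected here):

```python
# a minimal sketch: collect every model-status field reported as "irregular"
from dataclasses import asdict

status = asdict(model.get_model_status())
irregular = {key: value for key, value in status.items() if value == "irregular"}
if irregular:
    print(f"Inconsistent state detected in: {sorted(irregular)}")
```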

Convert the layer status into a pandas `DataFrame` for easier visual inspection.

```python
from dataclasses import asdict
import pandas as pd

df = pd.DataFrame(asdict(layer) for layer in model.get_layer_status())
```
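From there, one quick way to spot layers that disagree with each other is to count the distinct values per column (lists and dicts are cast to strings so they can be counted):

```python
# in a consistent model, every layer reports the same adapter state, so more
# than one distinct value in these columns hints at an irregular state
print(df["active_adapters"].astype(str).value_counts())
print(df["merged_adapters"].astype(str).value_counts())
```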

It is possible to get this information for non-PEFT models if they use PEFT layers under the hood, but some information, such as the `base_model_type` or the `peft_types`, cannot be determined in that case. As an example, you can call this on a [diffusers](https://huggingface.co/docs/diffusers/index) model like so:

```python
>>> import torch
>>> from diffusers import StableDiffusionPipeline
>>> from peft import get_model_status, get_layer_status

>>> path = "runwayml/stable-diffusion-v1-5"
>>> lora_id = "takuma104/lora-test-text-encoder-lora-target"
>>> pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
>>> pipe.load_lora_weights(lora_id, adapter_name="adapter-1")
>>> pipe.load_lora_weights(lora_id, adapter_name="adapter-2")
>>> get_layer_status(pipe.text_encoder)
[TunerLayerStatus(name='text_model.encoder.layers.0.self_attn.k_proj',
module_type='lora.Linear',
enabled=True,
active_adapters=['adapter-2'],
merged_adapters=[],
requires_grad={'adapter-1': False, 'adapter-2': True},
available_adapters=['adapter-1', 'adapter-2']),
TunerLayerStatus(name='text_model.encoder.layers.0.self_attn.v_proj',
module_type='lora.Linear',
enabled=True,
active_adapters=['adapter-2'],
merged_adapters=[],
requires_grad={'adapter-1': False, 'adapter-2': True},
available_adapters=['adapter-1', 'adapter-2']),
...]

>>> get_model_status(pipe.unet)
TunerModelStatus(
base_model_type='other',
adapter_model_type='None',
peft_types={},
trainable_params=797184,
total_params=861115332,
num_adapter_layers=128,
enabled=True,
active_adapters=['adapter-2'],
merged_adapters=[],
requires_grad={'adapter-1': False, 'adapter-2': True},
available_adapters=['adapter-1', 'adapter-2'],
)
```
4 changes: 4 additions & 0 deletions docs/source/package_reference/peft_model.md
@@ -71,3 +71,7 @@ A `PeftModel` for mixing different adapter types (e.g. LoRA and LoHa).
[[autodoc]] utils.get_peft_model_state_dict

[[autodoc]] utils.prepare_model_for_kbit_training

[[autodoc]] get_layer_status

[[autodoc]] get_model_status
2 changes: 1 addition & 1 deletion docs/source/task_guides/prompt_based_methods.md
@@ -90,7 +90,7 @@ def preprocess_function(examples, text_column="Tweet text", label_column="text_l
model_inputs["attention_mask"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[
"attention_mask"
][i]
labels["input_ids"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids
labels["input_ids"][i] = [-100] * (max_length - len(label_input_ids)) + label_input_ids
model_inputs["input_ids"][i] = torch.tensor(model_inputs["input_ids"][i][:max_length])
model_inputs["attention_mask"][i] = torch.tensor(model_inputs["attention_mask"][i][:max_length])
labels["input_ids"][i] = torch.tensor(labels["input_ids"][i][:max_length])
2 changes: 2 additions & 0 deletions src/peft/__init__.py
@@ -44,6 +44,8 @@
PeftModelForTokenClassification,
PeftModelForQuestionAnswering,
PeftModelForFeatureExtraction,
get_layer_status,
get_model_status,
)
from .tuners import (
AdaptionPromptConfig,
6 changes: 6 additions & 0 deletions src/peft/mixed_model.py
@@ -311,6 +311,12 @@ def unload(self, *args: Any, **kwargs: Any):
"""
return self.base_model.unload(*args, **kwargs)

def get_layer_status(self):
raise TypeError(f"get_layer_status is not supported for {self.__class__.__name__}.")

def get_model_status(self):
raise TypeError(f"get_model_status is not supported for {self.__class__.__name__}.")

@classmethod
def _split_kwargs(cls, kwargs: dict[str, Any]):
return PeftModel._split_kwargs(kwargs)
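The practical effect is that the status helpers now fail loudly on mixed models instead of returning incomplete data. A hedged sketch of the resulting behavior (the toy base model and `target_modules` are made up; this assumes `get_peft_model(..., mixed=True)` returns a `PeftMixedModel`, as in current PEFT):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

# made-up toy model; "0" targets the first (and only) Linear layer by name
base = nn.Sequential(nn.Linear(8, 8))
mixed = get_peft_model(base, LoraConfig(target_modules=["0"]), mixed=True)

try:
    mixed.get_layer_status()
except TypeError as exc:
    print(exc)  # get_layer_status is not supported for PeftMixedModel.
```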