Skip to content

Commit

Permalink
Helper function to update model signature (#784)
Browse files Browse the repository at this point in the history
Provides helper functions in peft.helpers to update the signature of the
forward or generate method of a PeftModel (or subclass). This can be
useful because the wrapping class may override the docstring and type
annotations of the underlying base model. Applying the helper functions
will restore those, leading to better tab completion, help text, etc.

For the time being, these helper functions are purely optional to use.
At a later stage, we may consider applying them automatically, but that
would require testing to ensure that nothing breaks.
  • Loading branch information
kiansierra authored Aug 10, 2023
1 parent 7d44026 commit 412d7bc
Showing 1 changed file with 113 additions and 0 deletions.
113 changes: 113 additions & 0 deletions src/peft/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import inspect
from copy import deepcopy
from functools import update_wrapper
from types import MethodType

from .peft_model import PeftModel


def update_forward_signature(model: PeftModel):
"""
Args:
Updates the forward signature of the PeftModel to include parents class signature
model (`PeftModel`): Peft model to update the forward signature
Example:
```python
>>> from transformers import WhisperForConditionalGeneration
>>> from peft import get_peft_model, LoraConfig, update_forward_signature
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> peft_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"])
>>> peft_model = get_peft_model(model, peft_config)
>>> update_forward_signature(peft_model)
```
"""

# Only update signature when the current forward signature only has *args and **kwargs
current_signature = inspect.signature(model.forward)
if (
len(current_signature.parameters) == 2
and "args" in current_signature.parameters
and "kwargs" in current_signature.parameters
):
forward = deepcopy(model.forward.__func__)
update_wrapper(
forward, type(model.get_base_model()).forward, assigned=("__doc__", "__name__", "__annotations__")
)
model.forward = MethodType(forward, model)


def update_generate_signature(model: PeftModel):
"""
Args:
Updates the generate signature of a PeftModel with overriding generate to include parents class signature
model (`PeftModel`): Peft model to update the generate signature
Example:
```python
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
>>> from peft import get_peft_model, LoraConfig, TaskType, update_generate_signature
>>> model_name_or_path = "bigscience/mt0-large"
>>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
>>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
>>> peft_config = LoraConfig(
... task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
... )
>>> peft_model = get_peft_model(model, peft_config)
>>> update_generate_signature(peft_model)
>>> help(peft_model.generate)
```
"""
if not hasattr(model, "generate"):
return
current_signature = inspect.signature(model.generate)
if (
len(current_signature.parameters) == 2
and "args" in current_signature.parameters
and "kwargs" in current_signature.parameters
) or (len(current_signature.parameters) == 1 and "kwargs" in current_signature.parameters):
generate = deepcopy(model.generate.__func__)
update_wrapper(
generate,
type(model.get_base_model()).generate,
assigned=("__doc__", "__name__", "__annotations__"),
)
model.generate = MethodType(generate, model)


def update_signature(model: PeftModel, method: str = "all"):
"""
Args:
Updates the signature of a PeftModel include parents class signature for forward or generate method
model (`PeftModel`): Peft model to update generate or forward signature method (`str`): method to update
signature choose one of "forward", "generate", "all"
Example:
```python
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
>>> from peft import get_peft_model, LoraConfig, TaskType, update_signature
>>> model_name_or_path = "bigscience/mt0-large"
>>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
>>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
>>> peft_config = LoraConfig(
... task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
... )
>>> peft_model = get_peft_model(model, peft_config)
>>> update_signature(peft_model)
>>> help(peft_model.generate)
```
"""
if method == "forward":
update_forward_signature(model)
elif method == "generate":
update_generate_signature(model)
elif method == "all":
update_forward_signature(model)
update_generate_signature(model)
else:
raise ValueError(f"method {method} is not supported please choose one of ['forward', 'generate', 'all']")

0 comments on commit 412d7bc

Please sign in to comment.