
Peft model signature #784

Merged · 18 commits · Aug 10, 2023
Conversation

@kiansierra (Contributor)

#783
This PR modifies the signature of the default `PeftModel.forward` to be identical to the one in the base model.

from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model

model_name_or_path = "openai/whisper-large-v2"
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True)
config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
peft_model = get_peft_model(model, config)
??peft_model.forward

The output of the above cell is shown below, and it is the same as for the original model:

Signature: model.forward(input_features: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) -> Union[Tuple[torch.Tensor], transformers.modeling_outputs.Seq2SeqLMOutput]
Docstring:
The [`WhisperForConditionalGeneration`] forward method, overrides the `__call__` special method.

<Tip>

Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.

</Tip>

Args:
    input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
        Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
        loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
        the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
        [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
        tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
    attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in
        `[0, 1]`:

        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.

        [What are attention masks?](../glossary#attention-mask)
    decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
        Indices of decoder input sequence tokens in the vocabulary.

        Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
        [`PreTrainedTokenizer.__call__`] for details.

        [What are decoder input IDs?](../glossary#decoder-input-ids)

        Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
        `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
        `past_key_values`).
    decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
        Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
        be used by default.

        If you want to change padding behavior, you should read
        [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
        paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
    head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
        Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

        - 1 indicates the head is **not masked**,
        - 0 indicates the head is **masked**.

    decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
        Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

        - 1 indicates the head is **not masked**,
        - 0 indicates the head is **masked**.

    cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
        Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

        - 1 indicates the head is **not masked**,
        - 0 indicates the head is **masked**.

    encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
        Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
        `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
        hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
        `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

        Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
        blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

        If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
        don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
        `decoder_input_ids` of shape `(batch_size, sequence_length)`.
    decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
        Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
        representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
        input (see `past_key_values`). This is useful if you want more control over how to convert
        `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
    use_cache (`bool`, *optional*):
        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
        `past_key_values`).
    output_attentions (`bool`, *optional*):
        Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
        tensors for more detail.
    output_hidden_states (`bool`, *optional*):
        Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
        more detail.
    return_dict (`bool`, *optional*):
        Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
        or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
        only computed for the tokens with labels in `[0, ..., config.vocab_size]`.


    Returns:
        [`transformers.modeling_outputs.Seq2SeqLMOutput`] or `tuple(torch.FloatTensor)`: A [`transformers.modeling_outputs.Seq2SeqLMOutput`] or a tuple of
        `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
        elements depending on the configuration ([`WhisperConfig`]) and inputs.

        - **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) -- Language modeling loss.
        - **logits** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) -- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        - **past_key_values** (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`) -- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
          `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
          `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

          Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
          blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        - **decoder_hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
          one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

          Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        - **decoder_attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`.

          Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
          self-attention heads.
        - **cross_attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`.

          Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
          weighted average in the cross-attention heads.
        - **encoder_last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) -- Sequence of hidden-states at the output of the last layer of the encoder of the model.
        - **encoder_hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
          one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

          Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        - **encoder_attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`.

          Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
          self-attention heads.
  

    Example:

    ```python
    >>> import torch
    >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
    >>> from datasets import load_dataset

    >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
    >>> input_features = inputs.input_features

    >>> generated_ids = model.generate(inputs=input_features)

    >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    >>> transcription
    ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
    ```
Source:   
        def new_forward(self, *args, **kwargs):
            return self.get_base_model()(*args, **kwargs)
File:      /usr/local/lib/python3.10/dist-packages/peft/peft_model.py
Type:      method

@BenjaminBossan (Member)

Thanks for the PR. I see where you're coming from and think there is value in your suggestion. I'm always a little bit careful when working with inspect, local functions, overwriting attributes, and descriptors, as there can be strange edge cases where they fail.

When I read your suggestion, I immediately thought that functools.update_wrapper could be used instead, which I would prefer since it's from the standard library and thus already battle-tested. My idea was to basically do something like this:

# inside PeftModel.__init__
update_wrapper(self.forward, self.get_base_model().forward)

but that actually raises an error. I tried a few different things around that idea but nothing works. It seems to be somehow related to those methods being instance methods, but I'm really stumped. Maybe you have an idea what the issue is?

I tried to isolate the issue but at the end of the day, I'm really none the wiser:

# normal functions work
def spam(x, y=2) -> int:
    """spam"""
    return x - y
def eggs(x, y=3) -> float:
    """eggs"""
    return x + y + 1.0
update_wrapper(spam, eggs)
help(spam)
# prints
# Help on function eggs in module __main__:
# eggs(x, y=3) -> float
#     eggs

# unbound methods work
class SpamAndEggs:
    def spam(self, x, y=2) -> int:
        """spam"""
        return x - y
    def eggs(self, x, y=3) -> float:
        """eggs"""
        return x + y + 1.0
update_wrapper(SpamAndEggs.spam, SpamAndEggs.eggs)
help(SpamAndEggs().spam)
# prints
# Help on method eggs in module __main__:
# eggs(x, y=3) -> float method of __main__.SpamAndEggs instance
#     eggs

# bound instance methods don't work
class SpamAndEggs:
    def __init__(self):
        update_wrapper(self.spam, self.eggs)
    def spam(self, x, y=2) -> int:
        """spam"""
        return x - y
    def eggs(self, x, y=3) -> float:
        """eggs"""
        return x + y + 1.0
help(SpamAndEggs().spam)
# raises
# AttributeError: 'method' object has no attribute '__module__'
# even though the method actually *does* have an attribute '__module__' !!!

So only the last case with bound instance methods fails, but that would be exactly what we need here. I really feel like there must be a simple solution and I'm just not seeing it, but I couldn't get it to work and my google-fu failed me.
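
A plausible explanation, for what it's worth: method objects forward attribute reads to their underlying `__func__`, but they do not support attribute writes, which is exactly what `update_wrapper` attempts via `setattr`. A minimal illustration:

class C:
    def m(self):
        """original docs"""

c = C()
print(c.m.__doc__)  # 'original docs': reads are forwarded to c.m.__func__
try:
    c.m.__doc__ = 'new'  # writes on the method object itself fail
except AttributeError as e:
    print(e)  # attribute '__doc__' of 'method' objects is not writable
c.m.__func__.__doc__ = 'new'  # writing to the underlying function works
print(c.m.__doc__)  # 'new'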

@kiansierra (Contributor, Author)

kiansierra commented Aug 3, 2023

It seems the main issue with the above approach (skipping the `__module__` part, since we probably don't want to overwrite that) is that `__doc__` is not writable on a bound method:

from functools import update_wrapper

class SpamAndEggs:
    def __init__(self):
        update_wrapper(self.eggs, self.spam, assigned=('__doc__', '__name__', '__annotations__'))
    def spam(self, x, y=2) -> int:
        """spam"""
        return x - y
    def eggs(self, x, y=3) -> float:
        """eggs"""
        
        return x + y + 1.0
help(SpamAndEggs().eggs)
# raises
# AttributeError: attribute '__doc__' of 'method' objects is not writable

which occurs again if you do

class SpamAndEggs:
    def spam(self, x, y=2) -> int:
        """spam"""
        return x - y
    def eggs(self, x, y=3) -> float:
        """eggs"""
        
        return x + y + 1.0
SpamAndEggs().eggs.__doc__ = 'spam'

but this variant does work

from functools import update_wrapper, wraps

class SpamAndEggs:
    def __init__(self):
        self._update_wrapper()
        pass
    def _update_wrapper(self):
        def new_eggs(*args, **kwargs):
            return self.bacon(*args, **kwargs)
        update_wrapper(new_eggs, self.spam, assigned=('__doc__', '__name__', '__annotations__'))
        self.eggs = new_eggs
        
    def spam(self, x, y=2) -> int:
        """spam"""
        return x - y
    def eggs(self, x, y=3) -> float:
        """eggs"""
        
        return x + y + 1.0
    
    def bacon(self, x, y=3) -> float:
        """bacon"""
        
        return x + y + 1.0
help(SpamAndEggs().eggs)
print(f"6 Eggs: {SpamAndEggs().eggs(6)}")
# Output: 
Help on function spam in module __main__:

spam(x, y=2) -> int
    spam

6 Eggs: 10.0

It has a `bacon` method to avoid hitting the recursion limit, as `new_eggs` would otherwise call itself once it is overwritten in `self.eggs = new_eggs`.

This implementation seems to work on PEFT too

def _update_forward_signature(self):
    def new_forward(*args, **kwargs):
        return self.get_base_model()(*args, **kwargs)

    update_wrapper(new_forward, self.base_model.forward, assigned=('__doc__', '__name__', '__annotations__'))
    self.forward = new_forward

I've run the tests locally and the following fail with the current implementation and the changes proposed above, but they also fail on the main branch:

FAILED tests/test_config.py::PeftConfigTester::test_prompt_encoder_warning_num_layers - AssertionError: assert 'for PromptEn...ers are used.' == 'for MLP, the.....
FAILED tests/test_gpu_examples.py::PeftBnbGPUExampleTests::test_4bit_adalora_causalLM - ValueError: You can't train a model that has been loaded in 8-bit precis...
FAILED tests/test_stablediffusion.py::StableDiffusionModelTester::test_disable_adapter_0_test_hf_internal_testing_tiny_stable_diffusion_torch_lora - AssertionError: tensor(False) is not true

I had to comment out `addopts = "--cov=src/peft --cov-report=term-missing"` from pyproject.toml.

@BenjaminBossan (Member)

> but this variant does work

Thanks for investigating. Yes, your proposal works, I could reproduce it, but unfortunately it also requires a locally defined function, which is one of the things I wanted to avoid. E.g. this fails now:

import pickle
_ = pickle.dumps(SpamAndEggs())
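# raises something like (exact wording varies by Python version):
# AttributeError: Can't pickle local object 'SpamAndEggs._update_wrapper.<locals>.new_eggs'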

So the PeftModel class, which is currently pickle-able, wouldn't be anymore after the change. That was probably also true for your previous implementation; I haven't tested it.

Even though your suggestion is a nice quality of life improvement, I wouldn't want to sacrifice this functionality for it. If we find some way to make it work without that compromise, I would be in favor.

@BenjaminBossan (Member)

> I've run the tests locally and the following fail with the current implementation and the changes proposed above, but they also fail on the main branch

Not sure what the issue is here, but CI is green, so maybe it's something about your local environment.

@kiansierra (Contributor, Author)

I see, that is an issue for sure. It can be made picklable by changing the local function to a module-level function, but that creates its own set of issues, depending on whether you include `__name__` in `assigned`.

Without `__name__`, the updated signature is lost after a pickle roundtrip:

from functools import update_wrapper
from types import MethodType
import pickle

def eggs(self, *args, **kwargs):
    return self.bacon(*args, **kwargs)
        
class SpamAndEggs:
    def __init__(self):
        self._update_wrapper()
        pass
    def _update_wrapper(self):

        update_wrapper(eggs, SpamAndEggs.spam, assigned=('__doc__',  '__annotations__'))
        self.eggs = MethodType(eggs, self)
        
    def spam(self, x, y=2) -> int:
        """spam"""
        return x - y
    def eggs(self, x, y=3) -> float:
        """eggs"""
        
        return x + y + 1.0
    
    def bacon(self, x, y=3) -> float:
        """bacon"""
        
        return x + y + 1.0
    
spam_and_eggs = SpamAndEggs()
help(spam_and_eggs.eggs)
help(spam_and_eggs.spam)
print(f"6 Eggs: {spam_and_eggs.eggs(6)}")
print(f"6 spam: {spam_and_eggs.spam(6)}")

with open("test.pkl", "wb") as f:
    pickle.dump(spam_and_eggs, f)
with open("test.pkl", "rb") as f:
    spam_and_eggs = pickle.load(f)
    
help(spam_and_eggs.eggs)
help(spam_and_eggs.spam)
print(f"6 Eggs: {spam_and_eggs.eggs(6)}")
print(f"6 spam: {spam_and_eggs.spam(6)}")

Output:

Help on method eggs in module __main__:

eggs(x, y=2) -> int method of __main__.SpamAndEggs instance
    spam

Help on method spam in module __main__:

spam(x, y=2) -> int method of __main__.SpamAndEggs instance
    spam

6 Eggs: 10.0
6 spam: 4
Help on method eggs in module __main__:

eggs(x, y=3) -> float method of __main__.SpamAndEggs instance
    eggs

Help on method spam in module __main__:

spam(x, y=2) -> int method of __main__.SpamAndEggs instance
    spam

6 Eggs: 10.0
6 spam: 4

While if you include `__name__`, the unpickled object loads the wrong functionality (`eggs` ends up bound to `spam`):

from functools import update_wrapper, wraps
from types import MethodType
import pickle

def eggs(self, *args, **kwargs):
    return self.bacon(*args, **kwargs)
        
class SpamAndEggs:
    def __init__(self):
        self._update_wrapper()
        pass
    def _update_wrapper(self):

        update_wrapper(eggs, SpamAndEggs.spam, assigned=('__doc__',  '__annotations__', '__name__'))
        self.eggs = MethodType(eggs, self)
        
    def spam(self, x, y=2) -> int:
        """spam"""
        return x - y
    def eggs(self, x, y=3) -> float:
        """eggs"""
        
        return x + y + 1.0
    
    def bacon(self, x, y=3) -> float:
        """bacon"""
        
        return x + y + 1.0
    
spam_and_eggs = SpamAndEggs()
help(spam_and_eggs.eggs)
help(spam_and_eggs.spam)
print(f"6 Eggs: {spam_and_eggs.eggs(6)}")
print(f"6 spam: {spam_and_eggs.spam(6)}")

with open("test.pkl", "wb") as f:
    pickle.dump(spam_and_eggs, f)
with open("test.pkl", "rb") as f:
    spam_and_eggs = pickle.load(f)
    
help(spam_and_eggs.eggs)
help(spam_and_eggs.spam)
print(f"6 Eggs: {spam_and_eggs.eggs(6)}")
print(f"6 spam: {spam_and_eggs.spam(6)}")
Output:

Help on method spam in module __main__:

spam(x, y=2) -> int method of __main__.SpamAndEggs instance
    spam

Help on method spam in module __main__:

spam(x, y=2) -> int method of __main__.SpamAndEggs instance
    spam

6 Eggs: 10.0
6 spam: 4
Help on method spam in module __main__:

spam(x, y=2) -> int method of __main__.SpamAndEggs instance
    spam

Help on method spam in module __main__:

spam(x, y=2) -> int method of __main__.SpamAndEggs instance
    spam

6 Eggs: 4
6 spam: 4

I think maybe the simplest solution might be to implement a method in the PeftModel that allows the user to update the signature manually: `peft_model.update_forward_signature()`.

@BenjaminBossan (Member)

Very nice investigation, good idea to look at the functionality after a pickle roundtrip.

I'm still astonished that update_wrapper doesn't just work, I wonder if it's a bug or if I just need to learn more about the Python object model. Especially this code almost makes me believe it's a bug:

from functools import update_wrapper

class Mock:
    def foo(self, x, y=3) -> int:
        """Some docs here"""
        return x - y

class Wrapper:
    def __init__(self, mock):
        self.foo = update_wrapper(self.foo, mock.foo)
    def foo(self, x, y=10) -> float:
        """Other docs"""
        return x + y

wrapper = Wrapper(Mock())
# raises: 
# ---> 56         setattr(wrapper, attr, value)
# AttributeError: 'method' object has no attribute '__module__'

# entering debugger
ipdb>  hasattr(wrapper, '__module__')
True
ipdb>  setattr(wrapper, '__module__', value)
*** AttributeError: 'method' object has no attribute '__module__'

I tried a few variations of this, nothing worked.

Coming back to the issue at hand, I think at this point, the solution is too complicated for what is, at least in theory, a very simple thing.

I wonder if it wouldn't be easier to just update the forward docstring to say something along the lines of:

To get the help for the base model's forward method, please call
    >>> help(peft_model.get_base_model().forward)

As for your PR, I think the code could probably be factored out as a stand-alone function which could be put somewhere (a gist or maybe even PEFT docs) and then we can reference it somewhere for people who really wish to update their docs/signatures etc. This way, people can opt in but the PEFT code itself does not have this somewhat complicated piece of code, which increases maintenance burden.

@kiansierra (Contributor, Author)

Yes, I had suspicions it would not load correctly even if it saved successfully.
I've written the gist for anyone that might be interested in using it.
I also refactored the function into `utils.other`, in case that might be somewhere more suited for it.

@BenjaminBossan (Member) left a comment

Thanks for making the updates. Could you please undo the unrelated changes concerning the import order?

Regarding this new function, I'm not sure if it's really a utility function. I think utility functions are more like functions used by other PEFT classes / functions under the hood, i.e. deep down in the call stack. For me, update_forward_signature is more of a helper function, which sits at the top of the call stack. So I would rather introduce a new Python module for helper functions and put it there.

Before putting any more effort into this PR though, I would first like to get the opinion of @pacman100 whether we want to have this functionality in PEFT or not.

(Review comments on src/peft/peft_model.py and src/peft/utils/other.py: outdated, resolved)
@BenjaminBossan (Member)

I tried something else which looked like it could work:

class BoundMethod:
    def __init__(self, f):
        self.f = f
        update_wrapper(self, f)

    def __get__(self, obj, objtype=None):
        return self.f

    def __call__(self, *args, **kwds):
        return self.f(*args, **kwds)

# in PeftModel.__init__:
self.forward = BoundMethod(model.forward)

Unfortunately, this manages to somehow detach the forward call, resulting in gradients not being computed :-/

@pacman100 (Contributor) left a comment

Yeah, this puzzled me a lot too. Earlier, doing the following in the init, `self.forward = self.get_base_model().forward`, seemed to preserve the signature but failed with `torch.nn.DataParallel`. I like the current solution; please address the comments made by Benjamin and then we should be good to go.

@kiansierra (Contributor, Author)

I've addressed the above comments; let me know if you would like the function to be in a different file so we can add typing: `def update_forward_signature(model: PeftModel)`.
If this is good, I can try something similar with the generate method.

@kiansierra (Contributor, Author)

Also, for some reason I currently don't understand, the code below seems to provide the correct signature:

from transformers import WhisperForConditionalGeneration
from peft import get_peft_model, LoraConfig
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
peft_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(model, peft_config)
print(type(peft_model))
help(peft_model.generate)

@BenjaminBossan (Member) left a comment

Would it be possible to make the `update_forward_signature` agnostic to the method itself? I.e. `update_signature(peft_model, method=forward)`? That might require some fiddling with `default_forward` (maybe partialing it). Or just have two options for forward and generate.

Also, what about my suggestion from earlier:

> I would rather introduce a new Python module for helper functions and put it there.

Do you think it makes sense?

> Also, for some reason I currently don't understand, the code below seems to provide the correct signature

Some methods just forward to the base model's method through `__getattr__`. In your example, `get_peft_model` returns a `PeftModel`, which doesn't have its own generate method (and neither has `LoraModel`), hence the help you see is the help for `WhisperForConditionalGeneration`.

src/peft/peft_model.py, lines 427 to 432 (at 7d44026):

def __getattr__(self, name: str):
    """Forward missing attributes to the wrapped module."""
    try:
        return super().__getattr__(name)  # defer to nn.Module's logic
    except AttributeError:
        return getattr(self.base_model, name)

@kiansierra (Contributor, Author)

> Would it be possible to make the `update_forward_signature` agnostic to the method itself? I.e. `update_signature(peft_model, method=forward)`? That might require some fiddling with `default_forward` (maybe partialing it). Or just have two options for forward and generate.

I think the best approach would be to create a function for each different functionality and then update the signature correspondingly

Also, what about my suggestion from earlier:

> I would rather introduce a new Python module for helper functions and put it there.

Do you think it makes sense?

Yes, I think it does make a lot of sense, allowing type hints and separating internal utils and external utils. I just wasn't sure if it was a change requirement or something to consider for the future.

The current implementation that modifies the generate method will only change the signature in those models where generate has been overridden by a PeftModel, since those whose generate method is pulled forward with `__getattr__` already have the correct signature.

Thanks for the clarification on `__getattr__`, I didn't pick up on that.

@BenjaminBossan (Member) left a comment

> Yes, I think it does make a lot of sense, allowing type hints and separating internal utils and external utils. I just wasn't sure if it was a change requirement or something to consider for the future.

Thanks for making that change. IMO, we don't need to import it into `__init__.py`; I prefer if users need to explicitly import it with `from peft.helpers import ...`, as this makes it clear that it's a helper function.
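
For illustration, usage would then look roughly like this (a sketch, assuming the helper ends up in peft.helpers as suggested):

from peft.helpers import update_forward_signature

update_forward_signature(peft_model)
help(peft_model.forward)  # now shows the base model's signature and docstring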

> The current implementation that modifies the generate method will only change the signature in those models where generate has been overridden by a PeftModel, since those whose generate method is pulled forward with `__getattr__` already have the correct signature.

Nice.

Something that has me concerned a bit is that we now have significant code duplication for the generate methods. I guess we cannot just use the method from the class because of the issues we discussed earlier, thus requiring the use of functions instead of methods? If so, I wonder if it is not better to drop generate, as the code duplication is a liability (easy to forget to update the functions here).

(Review comment on src/peft/helpers.py: outdated, resolved)
@kiansierra (Contributor, Author)

I think with this change it might even be possible to implement it in the PeftModel `__init__`, since it doesn't rely on local functions.
By using the pattern `obj.method.__func__` we unbind the method from the object and can apply `update_wrapper` as we initially intended to do, then bind it back to the object with `obj.method = MethodType(method, obj)`.
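
A minimal sketch of that pattern, purely illustrative rather than the exact merged code:

import inspect
from functools import update_wrapper
from types import MethodType

def update_forward_signature(model):
    # Only touch the generic (*args, **kwargs) wrapper, never a real forward.
    current_signature = inspect.signature(model.forward)
    if set(current_signature.parameters) != {'args', 'kwargs'}:
        return
    # model.forward.__func__ is a plain function, so its metadata is writable
    # (note that it is shared by the class, so this affects all instances).
    forward = model.forward.__func__
    update_wrapper(forward, type(model.get_base_model()).forward,
                   assigned=('__doc__', '__name__', '__annotations__'))
    # Bind the updated function back to the instance.
    model.forward = MethodType(forward, model)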

@BenjaminBossan (Member)

> I think with this change it might even be possible to implement it in the PeftModel `__init__`, since it doesn't rely on local functions. By using the pattern `obj.method.__func__` we unbind the method from the object and can apply `update_wrapper` as we initially intended to do, then bind it back to the object with `obj.method = MethodType(method, obj)`.

Oh interesting. If we want to apply that by default, it would require some rigorous testing (i.e. added unit tests) to ensure that we don't break anything. If you're willing to do that and other maintainers agree, we can make the change. Otherwise, I'm also happy with the separate helper functions.

(What I still don't understand is what is the issue with methods in general, but that's for another day.)

@kiansierra (Contributor, Author)

I think this works fine for now; if you decide this should be implemented by default, feel free to tag me and I'll try to help out.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Aug 10, 2023

The documentation is not available anymore as the PR was closed or merged.

@BenjaminBossan (Member)

@kiansierra could you please run `make style`?

@kiansierra (Contributor, Author)

Done, I forgot to install the doc-builder.

@BenjaminBossan (Member)

Oh, sorry I missed it, but could you please remove the imports from `__init__.py`? The reason I mentioned earlier:

> IMO, we don't need to import it into `__init__.py`; I prefer if users need to explicitly import it with `from peft.helpers import ...`, as this makes it clear that it's a helper function.

@kiansierra (Contributor, Author)

Removed, sorry I didn't pick that up; I was focused on the code duplication at the time.

@BenjaminBossan (Member) left a comment

Fantastic, great work and nice investigation of the issue. Thanks a lot.

@kiansierra (Contributor, Author)

Thanks a lot for your feedback and guidance on this topic

@BenjaminBossan BenjaminBossan merged commit 412d7bc into huggingface:main Aug 10, 2023
11 checks passed