
ENH: LoRA support for dynamically dispatching to custom layers #1875

Conversation

BenjaminBossan (Member) commented on Jun 19, 2024

Resolves #1867

Description

This is an experimental feature with a private API for now. If this feature finds adoption, I will work on adding an official API.

With this PR, we allow users to register their own LoRA layer types. This way, they can add support for hitherto unsupported layer types, such as nn.Conv3d or nn.LSTM. Without this PR, their only option is to open a PR on PEFT that adds support for the new type and get it merged.

The custom dispatch mechanism also allows users to override the existing layer type mapping. For instance, they can provide their own lora.Linear layer type, instead of using the one from PEFT, to adapt nn.Linear layers.
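
To make the intended usage concrete, here is a minimal sketch of the registration flow. The layer class MyLoraConv3d is a hypothetical placeholder (its implementation is elided), the target module name is made up, and _register_custom_module is my understanding of the private helper this PR adds for filling the mapping; treat the exact names as assumptions rather than a definitive API reference.

```python
import torch.nn as nn

from peft import LoraConfig
from peft.tuners.lora.layer import LoraLayer


class MyLoraConv3d(nn.Module, LoraLayer):
    """Hypothetical custom LoRA layer wrapping nn.Conv3d."""
    ...  # full implementation elided; see the constraints discussed further down


config = LoraConfig(target_modules=["conv3d"])  # made-up target module name
# Map the base layer type to the custom LoRA layer type; internally this
# populates the private config._custom_modules attribute shown below.
config._register_custom_module({nn.Conv3d: MyLoraConv3d})

# With the mapping registered, building the PEFT model proceeds as usual,
# e.g. peft_model = get_peft_model(base_model, config), where base_model is
# any model containing nn.Conv3d layers.
```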

Implementation

The implementation required only a few changes because we already have a mechanism for dynamic dispatching in LoRA. It is currently used, for instance, to dynamically add quantized target layers when the corresponding quantization library is installed.

This existing mechanism is now extended to include user-provided LoRA layers, if any were passed. These are checked before the default PEFT-supported layers, as sketched below.
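
To picture that precedence, here is a conceptual sketch of the check order. This is not the actual PEFT source: the function name create_lora_layer is made up, and dispatch_default is assumed to be importable from peft.tuners.lora.layer as the built-in fallback dispatcher.

```python
from peft.tuners.lora.layer import dispatch_default  # assumed import path for the built-in fallback


def create_lora_layer(target, adapter_name, lora_config, **kwargs):
    # 1. User-provided mappings, stored on the config as _custom_modules,
    #    are consulted first, so they can also override built-in layer types.
    custom_mapping = getattr(lora_config, "_custom_modules", None) or {}
    for base_cls, custom_lora_cls in custom_mapping.items():
        if isinstance(target, base_cls):
            return custom_lora_cls(target, adapter_name, **kwargs)
    # 2. Only afterwards do the default PEFT dispatchers run (Linear, Embedding,
    #    Conv2d, quantized layers, ...); dispatch_default stands in for them here.
    return dispatch_default(target, adapter_name, lora_config=lora_config, **kwargs)
```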

What's missing for this to become an official API?

Right now, the main reason why this cannot be an official API is the question of how to persist the config. In the current implementation, we add an attribute that is a mapping from target layer type to LoRA layer type:

config._custom_modules == {CustomBaseLayer: CustomLoraLayer}

The entries of this dict are Python classes, so they cannot be JSON-serialized. We could think about ways to serialize and deserialize custom Python objects, but this is not trivial and potentially a security risk. Thus I would only really start working on this if the demand is sufficiently high. At that point, I would also add a public API instead of requiring the use of a private one.

As is, users can still save and load PEFT models with custom LoRA layers; they only need to add two lines of code to their scripts, as documented. A sketch of this workflow follows below.
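
Reusing the hypothetical MyLoraConv3d mapping from above, and again treating the private method name as an assumption, the save/load flow could look roughly like this:

```python
import torch.nn as nn

from peft import LoraConfig, PeftModel

# peft_model, base_model and MyLoraConv3d refer to the hypothetical objects
# from the registration sketch above.

# Saving works as usual:
peft_model.save_pretrained("path/to/adapter")

# Loading: the class mapping cannot be stored in the JSON config, so it has to
# be registered again on a freshly loaded config before restoring the model.
config = LoraConfig.from_pretrained("path/to/adapter")
config._register_custom_module({nn.Conv3d: MyLoraConv3d})
peft_model = PeftModel.from_pretrained(base_model, "path/to/adapter", config=config)
```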

We could also think about adding support for methods other than LoRA. However, this would require implementing the dynamic dispatch mechanism for those other methods, which right now only exists for LoRA.

BenjaminBossan added 2 commits June 19, 2024 14:37
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BenjaminBossan changed the title from "Enh lora dynamic dispatch custom layers" to "ENH: LoRA support for dynamically dispatching to custom layers" on Jun 19, 2024
younesbelkada (Contributor) left a comment
Thanks a lot for this clean integration, docs and API! Left only one nit but overall looks great!


When creating your custom LoRA module, please follow the same rules as the existing LoRA modules. For this, check the [LoRA layer implementation](https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py). Notable constraints to consider:

- The custom module should inherit from `nn.Module` and `peft.tuners.lora.layer.LoraLayer`
Contributor:
Here we should IMO state that the signature of the __init__ method should have base_layer and adapter_name in the correct order, otherwise the API will fail.

BenjaminBossan (Member, Author):
Good point, I added an entry for the __init__.
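
For reference, a minimal sketch of what that constraint looks like in a custom layer; everything beyond the argument order is only indicative, and the class name is made up:

```python
import torch.nn as nn

from peft.tuners.lora.layer import LoraLayer


class MyLoraLayer(nn.Module, LoraLayer):
    # The dispatcher instantiates the class roughly as
    # cls(base_layer, adapter_name, **lora_kwargs), so base_layer and
    # adapter_name must be the first two positional parameters, in that order.
    def __init__(self, base_layer, adapter_name, **kwargs):
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)
        self._active_adapter = adapter_name
        # A real implementation would now create the LoRA parameters for
        # adapter_name (e.g. via self.update_layer) and define forward/merge.
```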

stevhliu (Member) left a comment
Very cool! Left a comment about improving the introductory section a bit 😄

BenjaminBossan merged commit ef23712 into huggingface:main on Jun 25, 2024
14 checks passed
BenjaminBossan deleted the enh-lora-dynamic-dispatch-custom-layers branch on June 25, 2024 at 09:02

Successfully merging this pull request may close these issues: FEAT: LoRA: Possibility to dynamically register layers to dispatch