ENH Multi adapters in same batch: modules_to_save (#1990)
Extend the functionality of having different adapters in the same batch to also
work with `modules_to_save`.
saeid93 authored Sep 17, 2024
1 parent 18f3efe commit adf0a1d
Showing 4 changed files with 117 additions and 7 deletions.
3 changes: 3 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -369,6 +369,8 @@ output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_toke

Note that the order does not matter here, i.e. the samples in the batch don't need to be grouped by adapter as in the example above. We just need to ensure that the `adapter_names` argument is aligned correctly with the samples.

The same approach also works with the `modules_to_save` feature, which allows saving and reusing specific neural network layers, such as custom heads for classification tasks, across different LoRA adapters.
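
For illustration, here is a minimal sketch of what this can look like with `modules_to_save`. The base model (`facebook/opt-125m`), the target module names, and the adapter names are placeholders chosen for this example, not part of the diff:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "facebook/opt-125m"  # placeholder base model
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id)

# both adapters wrap the same layer via modules_to_save (here the LM head)
config0 = LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"])
config1 = LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"])
peft_model = get_peft_model(base_model, config0, adapter_name="adapter0")
peft_model.add_adapter("adapter1", config1)

inputs = tokenizer(["Hello, my dog", "Hello, my cat", "Hello, my fish"], return_tensors="pt", padding=True)
# one adapter name per sample; "__base__" selects the unmodified base model
adapter_names = ["__base__", "adapter0", "adapter1"]
output = peft_model.generate(**inputs, adapter_names=adapter_names, max_new_tokens=20)
```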

### Caveats

Using this feature has some drawbacks, namely:
@@ -378,6 +380,7 @@ Using this feature has some drawbacks, namely:
- You cannot pass `adapter_names` when some adapter weights were merged with the base weights using the `merge_adapter` method. Please unmerge all adapters first by calling `model.unmerge_adapter()`.
- For obvious reasons, this cannot be used after calling `merge_and_unload()`, since all the LoRA adapters will be merged into the base weights in this case.
- This feature does not currently work with DoRA, so set `use_dora=False` in your `LoraConfig` if you want to use it.
- The `modules_to_save` feature is currently only supported for layers of type `Linear`, `Embedding`, `Conv2d` and `Conv1d`.
- There is an expected overhead for inference with `adapter_names`, especially if the number of different adapters in the batch is high. This is because the batch size is effectively reduced to the number of samples per adapter. If runtime performance is your top priority, try the following:
  - Increase the batch size.
  - Try to avoid having a large number of different adapters in the same batch, prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only performing inference with a small handful of different adapters (see the sketch below).
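
As an illustration of that buffering strategy, here is a hypothetical helper (not part of PEFT) that groups samples by adapter so that each `generate` call only sees a single adapter; `samples` is assumed to be a list of equal-length `input_ids` tensors:

```python
from collections import defaultdict

import torch


def grouped_generate(peft_model, samples, adapter_names, batch_size=8):
    # bucket sample indices by the adapter they should use
    buckets = defaultdict(list)
    for idx, name in enumerate(adapter_names):
        buckets[name].append(idx)

    results = [None] * len(samples)
    for name, indices in buckets.items():
        for start in range(0, len(indices), batch_size):
            chunk = indices[start : start + batch_size]
            input_ids = torch.stack([samples[i] for i in chunk])
            # every row in this call uses the same adapter, so the batch stays homogeneous
            outputs = peft_model.generate(input_ids=input_ids, adapter_names=[name] * len(chunk))
            for row, i in zip(outputs, chunk):
                results[i] = row
    return results
```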
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/model.py
@@ -432,7 +432,7 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):

hook_handles = []
for module in self.modules():
if isinstance(module, LoraLayer):
if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
hook_handles.append(handle)
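
For context, the pre-forward hook mechanism used in this loop can be reproduced in isolation. The sketch below mirrors the hook registration with a toy module standing in for a `LoraLayer` or `ModulesToSaveWrapper`; it assumes PyTorch >= 2.0 for `with_kwargs=True`, and the hook function is a local illustration rather than the library's own definition:

```python
from functools import partial

import torch
from torch import nn


def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
    # with with_kwargs=True, PyTorch calls the hook as hook(module, args, kwargs);
    # returning (args, kwargs) injects the extra keyword argument into forward
    kwargs["adapter_names"] = adapter_names
    return args, kwargs


class EchoAdapterNames(nn.Module):
    # toy stand-in for a LoraLayer / ModulesToSaveWrapper
    def forward(self, x, adapter_names=None):
        print("adapter_names seen inside forward:", adapter_names)
        return x


module = EchoAdapterNames()
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=["adapter0", "adapter1"])
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
module(torch.zeros(2, 3))  # prints ['adapter0', 'adapter1']
handle.remove()
```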
64 changes: 59 additions & 5 deletions src/peft/utils/other.py
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import inspect
import os
import warnings
from contextlib import nullcontext
from typing import Optional, Tuple
from typing import Any, Optional

import accelerate
import torch
@@ -268,10 +270,62 @@ def _create_new_hook(self, old_hook):
new_hook = old_hook_cls(**filtered_old_hook_attr)
return new_hook

def forward(self, *args, **kwargs):
def _check_forward_args(self, x, *args, **kwargs):
"""Check if the arguments are compatible with the configs and state of the model"""
adapter_names = kwargs.get("adapter_names", None)
if adapter_names is None:
return

if len(x) != len(adapter_names):
msg = (
"Length of `adapter_names` should be the same as the number of inputs, but got "
f"{len(adapter_names)} and {len(x)} respectively."
)
raise ValueError(msg)

def _mixed_batch_forward(
self, input: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
) -> torch.Tensor:
# This is a special method that handles the case when users pass the argument `adapter_names`. This is an
# extra argument that allows mixing different adapters in the same batch at inference time.

SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv1d)

module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES])

if not isinstance(self.original_module, SUPPORTED_MODULES):
raise TypeError(f"Mixed batching is only supported for the following modules: {module_names}.")

unique_adapters = set(adapter_names)
sub_batch_indices_list = []

for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

results = [0 for _ in range(len(input))]

for i, active_adapter in enumerate(unique_adapters):
sub_batch = input[sub_batch_indices_list[i]]

if active_adapter == "__base__":
output = self.original_module(sub_batch, *args, **kwargs)
else:
output = self.modules_to_save[active_adapter](sub_batch, *args, **kwargs)

for index, j in enumerate(sub_batch_indices_list[i]):
results[j] = output[index]

return torch.stack(results)

def forward(self, x: torch.Tensor, *args, **kwargs):
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)

if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
return self.original_module(*args, **kwargs)
return self.modules_to_save[self.active_adapter](*args, **kwargs)
return self.original_module(x, *args, **kwargs)
if adapter_names is None:
return self.modules_to_save[self.active_adapter](x, *args, **kwargs)
return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)

def enable_adapters(self, enabled: bool):
"""Toggle the enabling and disabling of adapters
@@ -546,7 +600,7 @@ def get_auto_gptq_quant_linear(gptq_quantization_config):
return None


def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
"""
Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
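To make the control flow of `_mixed_batch_forward` above easier to follow, here is a standalone sketch of the same sub-batching idea, with plain `nn.Linear` heads as placeholders for the per-adapter `modules_to_save` copies and the `original_module`:

```python
import torch
from torch import nn

# placeholder heads; in PEFT these would be the per-adapter modules_to_save copies
heads = {"adapter0": nn.Linear(4, 2), "adapter1": nn.Linear(4, 2)}
base_head = nn.Linear(4, 2)  # stand-in for original_module

x = torch.randn(5, 4)
adapter_names = ["adapter0", "__base__", "adapter1", "adapter0", "adapter1"]

results = [None] * len(x)
for adapter in set(adapter_names):
    # collect the row indices belonging to this adapter and run them as one sub-batch
    indices = [i for i, name in enumerate(adapter_names) if name == adapter]
    sub_batch = x[indices]
    module = base_head if adapter == "__base__" else heads[adapter]
    out = module(sub_batch)
    # scatter the sub-batch outputs back into the original row order
    for pos, row_idx in enumerate(indices):
        results[row_idx] = out[pos]

output = torch.stack(results)  # shape (5, 2), rows aligned with the input batch
```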
55 changes: 54 additions & 1 deletion tests/test_custom_models.py
@@ -583,6 +583,29 @@ def forward(self, X):
return X


class MLPWithGRU(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.lin0 = nn.Linear(10, 20, bias=bias)
self.relu = nn.ReLU()
self.drop = nn.Dropout(0.5)
self.gru = nn.GRU(input_size=20, hidden_size=20, num_layers=1, batch_first=True, bias=bias)
self.fc = nn.Linear(20, 2, bias=bias)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = X.float()
X = self.lin0(X)
X = self.relu(X)
X = self.drop(X)
X = X.unsqueeze(1)
X, _ = self.gru(X)
X = X.squeeze(1)
X = self.fc(X)
X = self.sm(X)
return X


class MLP_LayerNorm(nn.Module):
def __init__(self, bias=True):
super().__init__()
@@ -3326,15 +3349,36 @@ def test_mixed_adapter_batches_lora_mlp(self, mlp_lora):

def test_mixed_adapter_batches_lora_different_target_layers(self, mlp_lora):
base_model = MLP().to(self.torch_device).eval()
# target different lora layers
config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin1"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_multiple_modules_to_save(self, mlp_lora):
base_model = MLP().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_unsupported_layer_raises(self, mlp_lora):
base_model = MLPWithGRU().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv1d)
module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES])
with pytest.raises(
TypeError, match=f"Mixed batching is only supported for the following modules: {module_names}."
):
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_partly_overlapping_target_layers(self, mlp_lora):
base_model = MLP().to(self.torch_device).eval()
# target different lora layers
@@ -3356,6 +3400,15 @@ def test_mixed_adapter_batches_lora_conv1d_emb(self):
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_conv1d_emb_multiple_modules_to_save(self):
base_model = ModelEmbConv1D().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)

def test_mixed_adapter_batches_lora_conv2d(self):
base_model = ModelConv2D().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["conv2d"], init_lora_weights=False)
