Address reviewer feedback
- make special_peft_forward_args an instance attribute
- simplify loop

Also:

- make test asserts more terse
BenjaminBossan committed Mar 15, 2024
1 parent 195590b commit 4812ea4
Showing 3 changed files with 18 additions and 31 deletions.
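The first change replaces a set literal that was duplicated across every forward/generate wrapper with a single instance attribute. A minimal sketch of the before/after pattern (illustrative standalone classes, not the actual PeftModel code):

# Before: each wrapper method rebuilt the same local set.
class Before:
    def forward(self, *args, **kwargs):
        special_peft_args = {"adapter_names"}  # duplicated at every call site
        kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
        return self.get_base_model()(*args, **kwargs)

# After: the set is defined once in __init__ and shared by all call sites.
class After:
    def __init__(self):
        # Special PEFT arguments that users can pass; they must be stripped
        # before the remaining kwargs reach the base model's forward.
        self.special_peft_forward_args = {"adapter_names"}

    def forward(self, *args, **kwargs):
        kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
        return self.get_base_model()(*args, **kwargs)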
33 changes: 13 additions & 20 deletions src/peft/peft_model.py
@@ -114,6 +114,9 @@ def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name
         self.modules_to_save = None
         self.active_adapter = adapter_name
         self.peft_type = peft_config.peft_type
+        # These args are special PEFT arguments that users can pass. They need to be removed before passing them to
+        # forward.
+        self.special_peft_forward_args = {"adapter_names"}

         self._is_prompt_learning = peft_config.is_prompt_learning
         if self._is_prompt_learning:
@@ -555,14 +558,12 @@ def forward(self, *args: Any, **kwargs: Any):
         Forward pass of the model.
         """
         with self._enable_peft_forward_hooks(*args, **kwargs):
-            special_peft_args = {"adapter_names"}
-            kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+            kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
             return self.get_base_model()(*args, **kwargs)

     def generate(self, *args, **kwargs):
         with self._enable_peft_forward_hooks(*args, **kwargs):
-            special_peft_args = {"adapter_names"}
-            kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+            kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
             return self.get_base_model().generate(*args, **kwargs)

     def _get_base_model_class(self, is_prompt_tuning=False):
@@ -925,10 +926,9 @@ def forward(
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         peft_config = self.active_peft_config
-        special_peft_args = {"adapter_names"}
         if not peft_config.is_prompt_learning:
             with self._enable_peft_forward_hooks(**kwargs):
-                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                 if peft_config.peft_type == PeftType.POLY:
                     kwargs["task_ids"] = task_ids
                 return self.base_model(
@@ -1122,8 +1122,7 @@ def forward(
                 kwargs["task_ids"] = task_ids

             with self._enable_peft_forward_hooks(**kwargs):
-                special_peft_args = {"adapter_names"}
-                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                 return self.base_model(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
@@ -1183,9 +1182,8 @@ def generate(self, *args, **kwargs):
         self.base_model.generation_config = self.generation_config
         try:
             if not peft_config.is_prompt_learning:
-                special_peft_args = {"adapter_names"}
                 with self._enable_peft_forward_hooks(*args, **kwargs):
-                    kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                    kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                     outputs = self.base_model.generate(*args, **kwargs)
             else:
                 outputs = self.base_model.generate(**kwargs)
@@ -1321,8 +1319,7 @@ def forward(
                 kwargs["task_ids"] = task_ids

             with self._enable_peft_forward_hooks(**kwargs):
-                special_peft_args = {"adapter_names"}
-                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                 return self.base_model(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
@@ -1436,9 +1433,8 @@ def generate(self, **kwargs):
         )
         try:
             if not peft_config.is_prompt_learning:
-                special_peft_args = {"adapter_names"}
                 with self._enable_peft_forward_hooks(**kwargs):
-                    kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                    kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                     outputs = self.base_model.generate(**kwargs)
             else:
                 if "input_ids" not in kwargs:
@@ -1585,8 +1581,7 @@ def forward(

         if not peft_config.is_prompt_learning:
             with self._enable_peft_forward_hooks(**kwargs):
-                special_peft_args = {"adapter_names"}
-                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                 if peft_config.peft_type == PeftType.POLY:
                     kwargs["task_ids"] = task_ids
                 return self.base_model(
@@ -1767,8 +1762,7 @@ def forward(
                 kwargs["task_ids"] = task_ids

             with self._enable_peft_forward_hooks(**kwargs):
-                special_peft_args = {"adapter_names"}
-                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                 return self.base_model(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
@@ -1945,8 +1939,7 @@ def forward(
                 kwargs["task_ids"] = task_ids

             with self._enable_peft_forward_hooks(**kwargs):
-                special_peft_args = {"adapter_names"}
-                kwargs = {k: v for k, v in kwargs.items() if k not in special_peft_args}
+                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                 return self.base_model(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
8 changes: 3 additions & 5 deletions src/peft/tuners/lora/model.py
@@ -384,12 +384,10 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
             raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")

         hook_handles = []
-        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
-        for key in key_list:
-            _, target, _ = _get_submodules(self.model, key)
-            if isinstance(target, LoraLayer):
+        for module in self.modules():
+            if isinstance(module, LoraLayer):
                 pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
-                handle = target.register_forward_pre_hook(pre_forward, with_kwargs=True)
+                handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                 hook_handles.append(handle)

         yield
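The loop simplification works because nn.Module.modules() already recurses through every submodule instance, so resolving qualified names via named_modules() and _get_submodules is redundant when all that is needed is an isinstance check. A runnable sketch of the idea (toy layer class, not the real LoraLayer):

import torch.nn as nn

class ToyLoraLayer(nn.Linear):
    # Stand-in for LoraLayer; only used to tag a module for the isinstance check.
    pass

model = nn.Sequential(nn.Linear(4, 4), ToyLoraLayer(4, 4))

# modules() yields the container and every nested submodule, so filtering by
# type finds the same targets the old named_modules()-based lookup did.
targets = [m for m in model.modules() if isinstance(m, ToyLoraLayer)]
assert len(targets) == 1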
8 changes: 2 additions & 6 deletions tests/testing_common.py
@@ -711,12 +711,8 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):

         atol, rtol = 1e-4, 1e-4
         # sanity check that there are enough outputs and that they are different
-        assert len(output_base) >= 3
-        assert len(output_adapter0) >= 3
-        assert len(output_adapter1) >= 3
-        assert len(logits_base) >= 3
-        assert len(logits_adapter0) >= 3
-        assert len(logits_adapter1) >= 3
+        assert len(output_base) == len(output_adapter0) == len(output_adapter1) >= 3
+        assert len(logits_base) == len(logits_adapter0) == len(logits_adapter1) >= 3
         assert not torch.allclose(output_base, output_adapter0, atol=atol, rtol=rtol)
         assert not torch.allclose(output_base, output_adapter1, atol=atol, rtol=rtol)
         assert not torch.allclose(output_adapter0, output_adapter1, atol=atol, rtol=rtol)
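The terser asserts rely on Python's chained comparisons, which evaluate pairwise from left to right. A quick standalone check of the idiom:

a, b, c = 5, 5, 5
# Equivalent to (a == b) and (b == c) and (c >= 3),
# with each operand evaluated at most once.
assert a == b == c >= 3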
