[PEFT] Final fixes (#26559)
* fix issues with PEFT

* logger warning futurewarning issues

* fixup

* adapt from suggestions

* oops

* rm test
younesbelkada authored Oct 3, 2023
1 parent ae9a344 commit 2aef9a9
Showing 4 changed files with 54 additions and 16 deletions.
9 changes: 7 additions & 2 deletions src/transformers/integrations/peft.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import warnings
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from ..utils import (
@@ -159,6 +160,10 @@ def load_adapter(
                 "The one in `adapter_kwargs` will be used."
             )
 
+        # Override token with adapter_kwargs' token
+        if "token" in adapter_kwargs:
+            token = adapter_kwargs.pop("token")
+
         if peft_config is None:
            adapter_config_file = find_adapter_config_file(
                peft_model_id,
@@ -381,8 +386,8 @@ def active_adapters(self) -> List[str]:
         return active_adapters
 
     def active_adapter(self) -> str:
-        logger.warning(
-            "The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning
+        warnings.warn(
+            "The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning
         )
 
         return self.active_adapters()[0]
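
For illustration, a minimal sketch of the two changes above from the caller's side — the adapter repo id and token are hypothetical placeholders, and `peft` must be installed for `load_adapter` to work:

import warnings

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# A `token` inside `adapter_kwargs` now takes precedence: it is popped out of
# `adapter_kwargs` and used as the download token.
model.load_adapter(
    "some-user/opt-350m-lora",  # hypothetical PEFT adapter repo
    adapter_kwargs={"token": "hf_..."},  # placeholder token
)

# `active_adapter()` now emits a real FutureWarning via `warnings.warn`;
# previously `FutureWarning` was passed to `logger.warning` as a stray
# positional argument, so no warning of that category was ever raised.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = model.active_adapter()
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Preferred, non-deprecated replacement:
print(model.active_adapters())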
18 changes: 11 additions & 7 deletions src/transformers/modeling_utils.py
@@ -1933,15 +1933,21 @@ def save_pretrained(
         if token is not None:
             kwargs["token"] = token
 
+        _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
+
         # Checks if the model has been loaded in 8-bit
-        if getattr(self, "is_loaded_in_8bit", False) and getattr(self, "is_8bit_serializable", False):
-            warnings.warn(
+        if (
+            getattr(self, "is_loaded_in_8bit", False)
+            and not getattr(self, "is_8bit_serializable", False)
+            and not _hf_peft_config_loaded
+        ):
+            raise ValueError(
                 "You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
-                " behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed.",
-                UserWarning,
+                " behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
             )
 
-        if getattr(self, "is_loaded_in_4bit", False):
+        # If the model has adapters attached, you can save the adapters
+        if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded:
             raise NotImplementedError(
                 "You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
             )
@@ -1982,8 +1988,6 @@ def save_pretrained(
         if self._auto_class is not None:
             custom_object_save(self, save_directory, config=self.config)
 
-        _hf_peft_config_loaded = getattr(model_to_save, "_hf_peft_config_loaded", False)
-
         # Save the config
         if is_main_process:
             if not _hf_peft_config_loaded:
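
The net effect on `save_pretrained` — a sketch assuming `bitsandbytes` and `peft` are installed and a CUDA device is available; the adapter repo id is hypothetical:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", load_in_8bit=True, device_map="auto"
)

# Without adapters, saving an 8-bit model with a non-serializable bitsandbytes
# version now raises ValueError instead of merely warning:
#     model.save_pretrained("opt-8bit")  # -> ValueError

# With a PEFT adapter attached (_hf_peft_config_loaded is True), saving is
# allowed for both 8-bit and 4-bit models and writes only the adapter files
# (adapter_model.bin / adapter_config.json), never pytorch_model.bin:
model.load_adapter("some-user/opt-350m-lora")  # hypothetical adapter repo
model.save_pretrained("opt-8bit-adapter")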
36 changes: 36 additions & 0 deletions tests/peft_integration/test_peft_integration.py
@@ -312,6 +312,42 @@ def test_peft_from_pretrained_kwargs(self):
                 # dummy generation
                 _ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
 
+    @require_torch_gpu
+    def test_peft_save_quantized(self):
+        """
+        Simple test that tests the basic usage of PEFT model save_pretrained with quantized base models
+        """
+        # 4bit
+        for model_id in self.peft_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
+
+                module = peft_model.model.decoder.layers[0].self_attn.v_proj
+                self.assertTrue(module.__class__.__name__ == "Linear4bit")
+                self.assertTrue(peft_model.hf_device_map is not None)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    peft_model.save_pretrained(tmpdirname)
+                    self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
+                    self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
+                    self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
+
+        # 8-bit
+        for model_id in self.peft_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+
+                module = peft_model.model.decoder.layers[0].self_attn.v_proj
+                self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
+                self.assertTrue(peft_model.hf_device_map is not None)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    peft_model.save_pretrained(tmpdirname)
+
+                    self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
+                    self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
+                    self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
+
     def test_peft_pipeline(self):
         """
         Simple test that tests the basic usage of PEFT model + pipeline
7 changes: 0 additions & 7 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -263,13 +263,6 @@ def test_generate_quality_config(self):
 
         self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    def test_warns_save_pretrained(self):
-        r"""
-        Test whether trying to save a model after converting it in 8-bit will throw a warning.
-        """
-        with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
-            self.model_8bit.save_pretrained(tmpdirname)
-
     def test_raise_if_config_and_load_in_8bit(self):
         r"""
         Test that loading the model with the config and `load_in_8bit` raises an error
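
The removed test asserted the old UserWarning, which no longer fires. Under the new behavior, an equivalent check (a hypothetical sketch, not part of this commit) would assert the ValueError instead:

    def test_raises_save_pretrained(self):
        r"""
        Sketch of a replacement check (not in this commit): saving an 8-bit
        model with a non-serializable bitsandbytes version and no adapters
        attached should now raise instead of warn.
        """
        with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
            self.model_8bit.save_pretrained(tmpdirname)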
